Derivative tracing in JAX
This guide provides a low-level look into how derivatives are traced in jax. In particular we will cover…
- What Jacobian-vector-products (JVP) and vector-Jacobian-products (VJP) are and why they form the basis of jax’s automatic differentiation pipeline
- How auto-diff tracing works and what distinguishes it from jax’s jit-tracing
- What residuals are and how they can blow up memory usage during backward propagation
- How and why to implement custom JVP and VJP rules
- How and why VJP rules for problems with certain symmetries may be replaced by simpler JVPs
Prerequisites
I assume that you have already used jax for a bit and have perhaps tried out jax.jit and jax.grad here and there. If you haven’t, it may be a good idea to play around with jax first and come back when you really want to understand in more detail how auto-diff works.
You can find the notebook that this guide is based on here.
Forward-mode differentiation
Consider a nested function of the form
\[x_n = f_{n}(... f_3(f_2(f_1(x_0))))\]We might write its evaluation as a series of evaluation steps
\[x_0 \xrightarrow{f_1} x_1 \xrightarrow{f_2} x_2 \xrightarrow{f_3} ... \xrightarrow{f_{n}} x_n\]We are interested in taking derivatives of the final values with respect to the initial values. For now, let us assume that (the vector) $x_0$ may be written as a function of a single scalar value $\alpha$.
In that case we can write the final derivative in terms of the initial one: \(\begin{align} \frac{\partial x_n}{\partial \alpha} &= ? \\ \dot{x}_0 &:= \frac{\partial x_0}{\partial \alpha} \\ \dot{x}_1 &:= \frac{\partial x_1}{\partial \alpha} = \frac{\partial x_1}{\partial x_0} \frac{\partial x_0}{\partial \alpha} = \frac{\partial x_1}{\partial x_0} \dot{x}_0 \\ \dot{x}_n &:= \frac{\partial x_n}{\partial \alpha} = \frac{\partial x_n}{\partial x_{n-1}} \dot{x}_{n-1}\\ \end{align}\) Here we multiply the Jacobian $J_n(x_{n-1}) = \frac{\partial x_n}{\partial x_{n-1}} $ with a tangent vector \(\dot{x}_{n-1}\) . The tangent vector has the same shape as $x_{n-1}$. We call these evaluation steps Jacobian-vector-products (JVP). We may visualize the flow of this computation as follows:
\[\begin{array}{r l c l c l c l c} \alpha_0 \rightarrow x_0 &\xrightarrow{f_1} & x_1 &\xrightarrow{f_2} & x_2 &\xrightarrow{f_3} &\cdots &\xrightarrow{f_n} & x_n \\ \searrow \quad &\searrow &&\searrow &&\searrow &&\searrow & \\ \dot{x}_0 &\xrightarrow{\text{JVP1}} & \dot{x}_1 &\xrightarrow{\text{JVP2}} & \dot{x}_2 &\xrightarrow{\text{JVP3}} & \cdots &\xrightarrow{\text{JVP}n} & \dot{x}_n \end{array}\]In the initial step $\alpha$ is used to define both $x_0$ and a specific tangent vector $\dot{x}_0$. Both $x_n$ (also referred to as primals) and the tangents $\dot{x}_n$ are then propagated step by step in what is called forward-mode differentiation. Note that the Jacobian in each step may explicitly depend on the input primals – indicated through arrows between the two graphs.
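To make the chain concrete, here is a small sketch for the composition $\sin(\exp(x))$: propagating the tangent step by step by hand, one Jacobian-vector-product per arrow, gives the same result as jax.jvp. (The example function is my own, chosen purely for illustration.)

```python
import jax
import jax.numpy as jnp

# f = f2 ∘ f1 with f1 = exp and f2 = sin
def f(x0):
    return jnp.sin(jnp.exp(x0))

x0 = jnp.array(0.5)
x0_dot = jnp.array(1.0)  # tangent of the input

# forward mode: primal and tangent are propagated together
x2, x2_dot = jax.jvp(f, (x0,), (x0_dot,))

# the same tangent, chained by hand: each step is one JVP
x1 = jnp.exp(x0)
x1_dot = jnp.exp(x0) * x0_dot          # J_1 · x0_dot
x2_dot_by_hand = jnp.cos(x1) * x1_dot  # J_2 · x1_dot

assert jnp.allclose(x2_dot, x2_dot_by_hand)
```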
Backward-mode differentiation
Let us consider a situation where we have a scalar function at the end of our computational graph and we want to know its derivative with respect to our input parameters. This is a common situation in machine learning and other optimization scenarios. E.g. for a final loss function $L$ we may write
\[\begin{align} \bar{x}_n &:= \frac{\partial L}{\partial x_n} \\ \bar{x}_{n-1} &:= \frac{\partial L}{\partial x_{n-1}} = \bar{x}_n \frac{\partial x_n}{\partial x_{n-1}} \\ &... \\ \bar{x}_0 &:= \frac{\partial L}{\partial x_{0}} = \bar{x}_1 \frac{\partial x_1}{\partial x_{0}} \end{align}\]We call $\bar{x}_n$ the cotangent of $x_n$ and it has again the same shape as $x_n$. Since we multiply with our Jacobian from the left side, we speak of a vector-Jacobian-product (VJP). JVP and VJP are simply related through a transpose of the Jacobian (and perform the same operation when the Jacobian is symmetric).
We may visualize the flow of the computation through the following graph:
\[\begin{array}{c l c l c l c l c l} x_0 &\xrightarrow{f_1} & x_1 &\xrightarrow{f_2} & x_2 &\xrightarrow{f_3} & \cdots &\xrightarrow{f_n} & x_n &\rightarrow L \\ &\searrow &&\searrow &&\searrow &&\searrow & & \swarrow \\ \bar{x}_0 &\xleftarrow{\text{VJP1}} & \bar{x}_1 &\xleftarrow{\text{VJP2}} & \bar{x}_2 &\xleftarrow{\text{VJP3}} & \cdots &\xleftarrow{\text{VJP}n} & \bar{x}_n & \end{array}\]We first evaluate all the primals $x_n$ in a forward computation until we reach the final loss function $L$. With $L$ we can define the cotangent $\bar{x}_n$ that we are interested in and then propagate it backward step-by-step (aka. back-propagation). An important detail is that each VJP may depend on the primals in some form. Therefore, to implement back-propagation it is necessary to save some residuals of the forward pass and to only discard them once the backward evaluation has reached the corresponding stage. (The residuals could be simply all the input parameters of each function, but they could also be less than that and may be customized in jax.) The management of residuals is the most common reason for implementing custom VJP operations.
Clearly backward differentiation is generally more complex than forward differentiation. However, in most practical applications backward differentiation is more relevant, e.g. for calculating gradients with respect to scalar loss functions.
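The transpose relationship between JVP and VJP can be checked numerically: for any tangent $u$ and cotangent $v$ we must have $\langle v, Ju \rangle = \langle J^T v, u \rangle$. A minimal sketch (the function f below is made up for illustration):

```python
import jax
import jax.numpy as jnp

def f(x):  # an arbitrary map from R^2 to R^3
    A = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
    return jnp.tanh(A @ x)

x = jnp.array([0.3, -0.7])
u = jnp.array([1.0, 2.0])        # input tangent
v = jnp.array([0.5, -1.0, 2.0])  # output cotangent

_, Ju = jax.jvp(f, (x,), (u,))   # JVP: J @ u
_, pullback = jax.vjp(f, x)
(JTv,) = pullback(v)             # VJP: J^T @ v

# <v, J u> == <J^T v, u>, since the VJP applies the transposed Jacobian
assert jnp.allclose(jnp.vdot(v, Ju), jnp.vdot(JTv, u), rtol=1e-5)
```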
A small recap on jax.jit tracing
Before we have a look at how autodiff tracing works in jax, let us briefly recap how jax.jit operates.
When jax.jit is called on a function and lowered, jax invokes the function with arguments replaced by tracers (but static arguments kept as they are).
Consider the example below where x becomes a tracer inside of the jitted context and all variables that depend on x also become tracers. Further, any jax functions that are called inside the jitted context create tracers (e.g. consider jnp.ones below). On the other hand the static variable “c” does not get replaced by a tracer.
The tracers are used to create a graph that may represent the computation and then later get compiled into a very efficient GPU program.
Conceptually, jit-tracers do not carry concrete values; they primarily carry shape and dtype information. Therefore, python control flow – which effectively acts as a pre-compiler – cannot depend on the values of the traced arrays.
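A minimal sketch of this restriction (the exact error type may vary between jax versions; recent versions raise a subclass of jax.errors.ConcretizationTypeError):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # python-level branching on a traced *value* fails at trace time,
    # since the tracer only carries shape and dtype
    if x[0] > 0:
        return x
    return -x

try:
    f(jnp.ones(3))
except jax.errors.ConcretizationTypeError as e:
    print("trace-time error:", type(e).__name__)
```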
import jax
import jax.numpy as jnp
from my_jax_utils import show_hlo_info

def f(x, c=3):
    print("x", x, type(x))
    print("c", c, type(c))
    b = jnp.ones((1024, 1024))
    print("b", b, type(b))
    d = x + b
    print("d", d, type(d))
    e = jnp.sin(d)
    print("e", type(e))
    return e

f_jit = jax.jit(f, static_argnames=('c',))
x = jnp.ones((1024, 1024))
print("Tracing...")
flower = f_jit.lower(x)
Tracing...
x JitTracer<float32[1024,1024]> <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
c 3 <class 'int'>
b JitTracer<float32[1024,1024]> <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
d JitTracer<float32[1024,1024]> <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
e <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
Jit-tracers contain additional attributes (in particular .parent) that keep track of the computational graph. You may think of the tracers as directed edges in a computational graph. Their parent nodes correspond to computational primitives that may have several inputs and outputs. Below you can see how we can traverse the computational graph of the output variable by following the parent relationship of our output tracer:
def f(x, c=3):
    b = jnp.ones((1024, 1024))
    d = x + b
    e = jnp.sin(d)
    print("e:", e.parent.primitive, "inputs:", e.parent.in_tracers)
    get_d = e.parent.in_tracers[0]  # effectively gets us a reference to d
    print("d:", get_d.parent.primitive, "inputs:", get_d.parent.in_tracers)
    get_x = get_d.parent.in_tracers[0]  # effectively gets us a reference to x
    print("x-parent:", get_x.parent)
    return e

f_jit = jax.jit(f, static_argnames=('c',))
x = jnp.ones((1024, 1024))
flower = f_jit.lower(x)
e: sin inputs: [JitTracer<float32[1024,1024]>]
d: add inputs: [JitTracer<float32[1024,1024]>, JitTracer<float32[1024,1024]>]
x-parent: None
In this scenario, you can see that we effectively recover the same computational graph that we get after jit compiling the function. (In general, the jit compiled graph may look slightly different, since additional optimizations are performed on the graph.)
show_hlo_info(f_jit, x)
-------- Memory usage of f ---------
const : 4 B
code : 10.9 kB
temp : 0 B
arg : 4.0 MB
output: 4.0 MB
alias : 0 B
peak : 8.0 MB
Auto-diff tracing
When we use jax’s built-in automatic differentiation operators, jax also replaces some arrays by tracers. However, these tracers operate quite differently to jit-tracers and they serve completely different purposes. Auto-diff tracers are used to rearrange function calls and residuals in a convenient manner to facilitate the forward- and backward-differentiation processes.
Most notably, auto-diff tracers carry actual values (not only dtypes/shapes) with them. Consider the following example:
def f(x):
    const = jnp.arange(len(x))
    out = x * const
    print("type(x): ", type(x), "value:", x.primal, x.tangent)
    print("type(const): ", type(const), const)
    print("type(out): ", type(out), "value:", out.primal, out.tangent)
    if x[0] == 1.:
        print("Branch taken")
    return out

x = jnp.ones((4,))
xtangent = jnp.array([0.1, 0.2, 0.1, 0.2])
print("JVP:")
out_val, out_tangent = jax.jvp(f, (x,), (xtangent,))
print("out_val", out_val, "out_tangent", out_tangent)
JVP:
type(x): <class 'jax._src.interpreters.ad.JVPTracer'> value: [1. 1. 1. 1.] [0.1 0.2 0.1 0.2]
type(const): <class 'jaxlib._jax.ArrayImpl'> [0 1 2 3]
type(out): <class 'jax._src.interpreters.ad.JVPTracer'> value: [0. 1. 2. 3.] [0. 0.2 0.2 0.6]
Branch taken
out_val [0. 1. 2. 3.] out_tangent [0. 0.2 0.2 0.6]
Here we have defined a function and evaluated its JVP. Recall that to evaluate the JVP we need the value of $x$ (referred to as the “primal”) and the value of the input tangent vector $\dot{x}$ (the “tangent”).
Some aspects worth noting here:
- x got replaced by an instance of jax._src.interpreters.ad.JVPTracer. We will refer to tracers like this as AD-tracers (short for “auto-diff”)
- All values that are inferred from x also got replaced by AD-tracers
- Values that do not depend on x (like const above) do not become AD-tracers. (Note that this is different from jax.jit, where the result of any computation becomes a tracer.) Their tangents are effectively zero, since they don’t depend on any traced inputs or outputs
- Every AD-tracer contains both the primal and the tangent. For JVP tracers, the tangents of the outputs of any operation are evaluated immediately and stored in the new tracer
- In principle it is possible to have value-dependent branches inside a differentiated function. However, this is no longer possible when composing with jax.jit and should therefore be avoided
- The result of the JVP call gives the output of the function and the tangents of the output
VJP tracers behave in a slightly more complex manner. Consider the following call to the VJP of the function above
print("VJP:")
out_val, cotangent_pullback = jax.vjp(f, x)
print("out_val", out_val)
print("cotangent_pullback", cotangent_pullback)
output_cotangent = jnp.array([0., 0.2, 0.2, 0.6])
print("x_cotangent:", cotangent_pullback(output_cotangent))
VJP:
type(x): <class 'jax._src.interpreters.ad.LinearizeTracer'> value: [1. 1. 1. 1.] JitTracer<float32[4]>
type(const): <class 'jaxlib._jax.ArrayImpl'> [0 1 2 3]
type(out): <class 'jax._src.interpreters.ad.LinearizeTracer'> value: [0. 1. 2. 3.] JitTracer<float32[4]>
Branch taken
out_val [0. 1. 2. 3.]
cotangent_pullback Partial(_HashableCallableShim(functools.partial(<function _vjp_pullback_wrapper at 0x75f7070ebd80>, 'f', [ShapedArray(float32[4])], (PyTreeDef(*), PyTreeDef((*,))))), Partial(_HashableCallableShim(functools.partial(<function vjp.<locals>.unbound_vjp at 0x75f706516840>, [(ShapedArray(float32[4]), None)], { lambda a:f32[4]; b:f32[4]. let c:f32[4] = mul b a in (c,) })), [Array([0., 1., 2., 3.], dtype=float32)]))
x_cotangent: (Array([0. , 0.2 , 0.4 , 1.8000001], dtype=float32),)
Note that:
- The tracers are now of a different type: jax._src.interpreters.ad.LinearizeTracer
- They still carry a primal value at each intermediate step
- Their tangents are replaced by jit-tracers that do not carry a value. The jit-tracers are used to record a (partial) graph for each operation. The full graph can only be assembled at the end of the function.
- The output of the VJP gives us the primal output and a function. To evaluate the cotangent of x, we need to call that function with the cotangent of the output parameters
- Although we called the VJP with the cotangent initialized to the value of the output tangent from the JVP, we do not get back the input tangent of the JVP. A JVP advects a tangent from input to output space, and a VJP a cotangent from output to input space. They don’t reverse each other’s operation, because they operate on different types of tangents.
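We can also confirm what the pullback computes by comparing it against the full Jacobian of the function from above: the returned cotangent equals $\bar{y} J$. (This consistency check is my own addition, not from the original notebook.)

```python
import jax
import jax.numpy as jnp

def f(x):
    return x * jnp.arange(len(x))

x = jnp.ones(4)
ybar = jnp.array([0., 0.2, 0.2, 0.6])  # output cotangent

_, pullback = jax.vjp(f, x)
(xbar,) = pullback(ybar)

# the pullback multiplies the cotangent with the Jacobian from the left
J = jax.jacfwd(f)(x)
assert jnp.allclose(xbar, ybar @ J)
```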
To appreciate a bit what the jit-tracer at .tangent does, consider the following:
def f(x):
    y = jnp.sin(x)
    print(y.tangent.parent.primitive, y.tangent.parent.in_tracers)
    ytang_in = y.tangent.parent.in_tracers
    print("inputs' parents", ytang_in[0].parent, ytang_in[1].parent)
    return y

val, cotangent_fun = jax.vjp(f, jnp.array([0., jnp.pi/3]))
mul [JitTracer<float32[2]>, JitTracer<float32[2]>]
inputs' parents None None
Note that the jit-tracer of y.tangent comes from a mul operation with two inputs that both have shape (2,). Consider that the VJP of the sin function $y = \sin(x)$ is

\[\bar{x} = \cos(x) \cdot \bar{y}\]

Therefore y.tangent is the result of this multiplication with two inputs: the residual $\cos(x)$ and the cotangent $\bar{y}$. When we inspect the parents in the graph we can see that these are passed as direct inputs to the computation of $\bar{x}$. This means that $\cos(x)$ is already precomputed in the forward pass and then reused in the backward pass.
If the concept of the residual confuses you: don’t worry, we will discuss it in more detail when writing custom VJPs later.
I couldn’t find a direct way to inspect the residuals during tracing / the forward pass. (It might be that they are kept in a python closure.) However, we can see the residuals with jax.ad_checkpoint.print_saved_residuals. You can see that it is indeed the result of a cos operation:
import jax.ad_checkpoint

def f(x):
    y = jnp.sin(x)
    return y

jax.ad_checkpoint.print_saved_residuals(f, jnp.array([0., jnp.pi/3]))
f32[2] output of cos from /tmp/ipykernel_840435/1615555903.py:2:8 (f)
jax.lax.stop_gradient
Sometimes we don’t want some intermediate result to be considered differentiable. For this you can simply use jax.lax.stop_gradient. You can see below that this converts an AD-tracer to its value, thus stopping tangent propagation for this value.
There was a bug in jax that made jax.lax.stop_gradient not work properly for integer-valued arrays. It was fixed in this issue, and you can find a forward/backward compatible version of stop_gradient in the issue’s comments.
def f(x):
    a = jnp.mean(x, axis=0)
    a2 = jax.lax.stop_gradient(a)
    print("type(x): ", type(x), "value:", x.primal, x.tangent)
    print("type(a): ", type(a), "value:", a.primal, a.tangent)
    print("type(a2): ", type(a2), "value:", a2)
    return x * a * a2

x = jnp.ones((4,))
xtangent = jnp.array([0.1, 0.2, 0.1, 0.2])
out_val, out_tangent = jax.jvp(f, (x,), (xtangent,))
type(x): <class 'jax._src.interpreters.ad.JVPTracer'> value: [1. 1. 1. 1.] [0.1 0.2 0.1 0.2]
type(a): <class 'jax._src.interpreters.ad.JVPTracer'> value: 1.0 0.15
type(a2): <class 'jaxlib._jax.ArrayImpl'> value: 1.0
Wrappers around VJPs and JVPs
Jax provides a number of convenience functions that wrap JVPs and VJPs in a friendlier interface. In particular, jax.jacfwd and jax.grad are relevant. It is useful to understand how they operate in case you ever need slightly modified versions of them.
As you can see below jax.jacfwd is probably implemented simply through multiple JVPs with unit input tangents (using jax’s vmap functionality):
def f(x):
    print(type(x), x.primal, x.tangent)
    return jnp.sin(x)

def custom_jacfwd(f):
    """Emulates jax.jacfwd"""
    # to get the Jacobian we evaluate the JVP of f
    # for input tangents [1,0,0,...], [0,1,0,...], ...
    def fjac(x):
        def fjvp(xtangent):
            return jax.jvp(f, (x,), (xtangent,))
        input_tangents = jnp.eye(len(x))
        out, out_tangents = jax.vmap(fjvp)(input_tangents)
        return out_tangents
    return fjac

x = jnp.linspace(0., 1., 3)
print("jax.jacfwd")
print(jax.jacfwd(f)(x))
print("My jacfwd")
print(custom_jacfwd(f)(x))
jax.jacfwd
<class 'jax._src.interpreters.ad.JVPTracer'> [0. 0.5 1. ] VmapTracer<float32[3]>
[[1. 0. 0. ]
[0. 0.87758255 0. ]
[0. 0. 0.5403023 ]]
My jacfwd
<class 'jax._src.interpreters.ad.JVPTracer'> [0. 0.5 1. ] VmapTracer<float32[3]>
[[1. 0. 0. ]
[0. 0.87758255 0. ]
[0. 0. 0.5403023 ]]
jax.grad is simply a wrapper around VJP that propagates back a unit output cotangent of a scalar function:
def loss(x):
    print("x:", type(x), x.primal, x.tangent)
    return jnp.sum(x * jnp.arange(len(x)))

def custom_grad(f):
    """Emulates jax.grad"""
    def grad_loss(x):
        vjp_func = jax.vjp(f, x)[1]
        output_cotangent = jnp.array(1.)
        input_cotangent = vjp_func(output_cotangent)[0]
        return input_cotangent
    return grad_loss

print("jax.grad:")
print(jax.grad(loss)(jnp.array([1., 2., 3.])))
print("custom_grad:")
print(custom_grad(loss)(jnp.array([1., 2., 3.])))
jax.grad:
x: <class 'jax._src.interpreters.ad.LinearizeTracer'> [1. 2. 3.] JitTracer<float32[3]>
[0. 1. 2.]
custom_grad:
x: <class 'jax._src.interpreters.ad.LinearizeTracer'> [1. 2. 3.] JitTracer<float32[3]>
[0. 1. 2.]
It should be easy for you to imagine how other wrappers like jax.value_and_grad operate. Higher order derivatives like jax.hessian are likely implemented through recursive / nested tracing (where tracer values are tracers themselves). However, we won’t go that deep here.
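For instance, jax.value_and_grad might be sketched along the same lines as custom_grad above. (This is my own illustration of the idea, not jax's actual implementation.)

```python
import jax
import jax.numpy as jnp

def custom_value_and_grad(f):
    """Emulates jax.value_and_grad via jax.vjp (sketch)."""
    def value_and_grad_f(x):
        val, pullback = jax.vjp(f, x)
        # unit cotangent, since the output is a scalar
        (grad,) = pullback(jnp.ones_like(val))
        return val, grad
    return value_and_grad_f

loss = lambda x: jnp.sum(x ** 2)
x = jnp.array([1., 2., 3.])
v1, g1 = jax.value_and_grad(loss)(x)
v2, g2 = custom_value_and_grad(loss)(x)
assert jnp.allclose(v1, v2) and jnp.allclose(g1, g2)
```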
Custom JVPs
Since it is relatively easy to efficiently compose most derivative operations from JVPs and VJPs, we will focus for the rest of this investigation on how these operate and how to customize and optimize them.
Since JVPs are fundamentally simpler, we will start with these.
@jax.custom_jvp
def f(x):
    print("Calling f: type(x): ", type(x))
    return jnp.sin(x)

def f_jvp(primals, tangents):
    print("Calling JVP: type(x): ", type(primals[0]))
    x, x_tangent = primals[0], tangents[0]
    return jnp.sin(x), x_tangent * jnp.cos(x)

f.defjvp(f_jvp)

x = jnp.array([0., jnp.pi/2, jnp.pi])
x_tangent = jnp.array([1., 1., 1.])
print("Evaluating JVP:")
out_val, out_tangent = jax.jvp(f, (x,), (x_tangent,))

print("\nEvaluating f:")
out = f(x)

def function_using_f(x):
    print("A function calling f: type(x):", type(x))
    y = f(x)
    print("type(y):", type(y))
    return y

print("\nEvaluating function_using_f with JVP:")
out_val, out_tangent = jax.jvp(function_using_f, (x,), (x_tangent,))
Evaluating JVP:
Calling JVP: type(x): <class 'jaxlib._jax.ArrayImpl'>
Evaluating f:
Calling f: type(x): <class 'jaxlib._jax.ArrayImpl'>
Evaluating function_using_f with JVP:
A function calling f: type(x): <class 'jax._src.interpreters.ad.JVPTracer'>
Calling JVP: type(x): <class 'jaxlib._jax.ArrayImpl'>
type(y): <class 'jax._src.interpreters.ad.JVPTracer'>
A few points are worth noting here:
- We had to define two functions for the JVP: one that gets evaluated when using f normally, and one that evaluates both f and its JVP at once.
- When evaluating the JVP, only the latter function gets called (and the former is not called at all). This interface makes sense, since parts of the evaluation of f and its JVP can often be reused, so it is often better to define a single function that does both.
- Both functions are always called with values (not tracers) and return values (not tracers)
- Inside of function_using_f, x is a tracer.
- We infer that the custom_jvp decorator handles the interface. If it gets called with an AD-tracer, it calls f_jvp with the primals/tangents of the tracer and returns the outputs as AD-tracers. If it gets called with a normal array, it simply calls f.
Jax bundles all inputs to the JVP into tuples and provides all non-differentiable inputs separately. The interface may be slightly awkward at times, especially with keyword arguments. Consider the following example
from functools import partial

@partial(jax.custom_jvp, nondiff_argnames=('a', 'flag'))
def f(x, a=2, flag=True):
    if flag:
        return x**a
    else:
        return x**-a

def f_jvp(a, flag, primals, tangents):
    x = primals[0]
    x_tangent = tangents[0]
    if flag:
        primal_out = x**a
        tangent_out = a * x**(a-1) * x_tangent
    else:
        primal_out = x**-a
        tangent_out = -a * x**(-a-1) * x_tangent
    return primal_out, tangent_out

f.defjvp(f_jvp)

def g(x):
    return f(x, 3, flag=True) + f(x, 3, flag=False)

jax.jvp(f, (jnp.array([2., 3., 4.]),), (jnp.array([1., 1., 1.]),))
(Array([ 4., 9., 16.], dtype=float32), Array([4., 6., 8.], dtype=float32))
This example works, but it is very easy to get wrong. Note that all the keyword arguments have to appear as flat positional arguments in the JVP function. If you wanted something slightly different – e.g. a differentiable keyword argument or a non-differentiable positional argument – feel free to try it… and good luck with that!
To make things simpler, I found it useful to define a nested function with the inner function only containing differentiable arguments. For the example above, we could write the following:
def f(x, a=2, flag=True):
    def f_impl(x):
        if flag:
            return x**a
        else:
            return x**-a

    def f_jvp(primals, tangents):
        x = primals[0]
        x_tangent = tangents[0]
        if flag:
            primal_out = x**a
            tangent_out = a * x**(a-1) * x_tangent
        else:
            primal_out = x**-a
            tangent_out = -a * x**(-a-1) * x_tangent
        return primal_out, tangent_out

    f_impl = jax.custom_jvp(f_impl)
    f_impl.defjvp(f_jvp)
    return f_impl(x)

def g(x):
    return f(x, 3, flag=True) + f(x, 3, flag=False)

jax.jvp(f, (jnp.array([2., 3., 4.]),), (jnp.array([1., 1., 1.]),))
(Array([ 4., 9., 16.], dtype=float32), Array([4., 6., 8.], dtype=float32))
While this achieves the same thing, I find it much easier to read. Adding additional arguments is much less scary, and optional arguments don’t need to be repeated several times. It also has the advantage that only f, but not f_jvp (which we don’t expect anyone to call), is visible at the module level – making it easier for users to find the right function. We’ll stick with this pattern for the rest of this tutorial.
As a note of caution: it is important to ensure that you do not take derivatives with respect to one of the other input parameters. If you do, an AD-tracer may leak into the JVP function, causing a tracer error. Anything that may have a derivative should be an argument of the inner function. If you want to call the function with AD-tracers at non-differentiable input locations, explicitly remove their tangents with jax.lax.stop_gradient.
One very useful aspect of defining JVPs is that jax can automatically infer the VJP from them:
def f(x):
    def f_impl(x):
        print("Calling f")
        return jnp.sin(x)

    def f_jvp(primals, tangents):
        x, x_tangent = primals[0], tangents[0]
        print("Calling JVP", type(x), type(x_tangent))
        return jnp.sin(x), x_tangent * jnp.cos(x)

    f_impl = jax.custom_jvp(f_impl)
    f_impl.defjvp(f_jvp)
    return f_impl(x)

val, cotangent_fun = jax.vjp(f, jnp.array([0., jnp.pi/2, jnp.pi]))
output_cotangent = jnp.array([1., 1., 1.])
print("Evaluate Tangent:")
input_cotangent = cotangent_fun(output_cotangent)
Calling JVP <class 'jaxlib._jax.ArrayImpl'> <class 'jax._src.interpreters.partial_eval.JaxprTracer'>
Evaluate Tangent:
Even though we defined only a JVP, jax is also able to use our function in backward differentiation. Apparently jax has replaced the tangent by a tracer that is used to capture the necessary graph for the backwards pass. How does this work? Possibly through a combination of jax.linearize (to define a linear map on the tangent space) and primitive transposition rules (jax.linear_transpose). Such transposition rules are implemented for all linear primitives in jax.
I am not entirely sure about this, but I suspect that implementing a custom_jvp and using jax’s VJP-through-JVP machinery may give close-to-optimal results for a VJP in many cases.
The opposite – a JVP through a custom_vjp function – is not possible in jax, probably because the function might not be bijective and the VJP function might not contain all the needed information.
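The suspected mechanism can be sketched directly with the public APIs: jax.linearize gives the linear tangent map at a point, and jax.linear_transpose turns that linear map into a VJP. (This is a sketch of the idea using a toy function, not jax's internal code.)

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)

x = jnp.array([0., jnp.pi / 3])

# linearize evaluates the primal and returns the linear JVP map at x
y, f_lin = jax.linearize(f, x)

# transposing the linear map yields the VJP
f_vjp = jax.linear_transpose(f_lin, x)

ybar = jnp.array([1., 1.])
(xbar,) = f_vjp(ybar)
assert jnp.allclose(xbar, jnp.cos(x) * ybar)
```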
Custom VJPs
To define a custom_vjp we need to provide three different functions:
- A function for normal evaluation
- A function that is evaluated in the forward pass of the VJP, returning the function value and residuals for the backward pass
- The VJP function that gets the residuals and output-space cotangents and returns the input-space cotangents
def f(x, y):
    def f_impl(x, y):
        print("Calling f")
        return jnp.sin(x) * jnp.cos(x) * y**2

    def f_fwd(x, y):
        print("Calling f_fwd")
        Sx, Cx = jnp.sin(x), jnp.cos(x)
        return Sx * Cx * y**2, (Sx, Cx, y)

    def f_vjp(residuals, cotangent):
        print("Calling f_vjp")
        Sx, Cx, y = residuals
        dfdx = (Cx**2 - Sx**2) * y**2
        dfdy = Cx * Sx * 2 * y
        return (cotangent * dfdx, cotangent * dfdy)

    f_impl = jax.custom_vjp(f_impl)
    f_impl.defvjp(f_fwd, f_vjp)
    return f_impl(x, y)

x = jnp.array([0., jnp.pi/2, jnp.pi])
y = jnp.array([1., 2., 3.])
print("normal eval")
val = f(x, y)
print("\nforward pass")
val, cotangent_fun = jax.vjp(f, x, y)
print("\nbackward pass")
output_cotangent = jnp.array([1., 1., 1.])
x_tangent = cotangent_fun(output_cotangent)
normal eval
Calling f
forward pass
Calling f_fwd
backward pass
Calling f_vjp
Note that in this example we save the values of sin(x), cos(x) and y as residuals to be reused in the backward pass so that we can avoid re-evaluating these later.
In practice, it would probably be better to just pass x as a residual here and re-evaluate sin(x) and cos(x) later. Passing around data is often more expensive on a GPU than recomputing it. However, for the sake of this example, imagine that sin and cos represent some very expensive calculations that we really don’t want to redo.
You can see that, compared to JVPs, the main conceptual difference is that our forward function may return residuals that are kept around for the backward pass. Since residuals may weigh heavily on memory usage, they are one of the most important aspects to optimize / minimize when designing good custom_vjp functions.
It is a good idea to verify that our VJP implementation is correct. This can easily be done with jax.test_util.check_grads, which compares the computed derivatives against finite differences:
from jax.test_util import check_grads

check_grads(f, (x, y), order=1, modes=("rev",))
Calling f_fwd
Calling f
Calling f
Calling f
Calling f_vjp
It is worth knowing that there exists also an alternative interface for the VJP via jax.custom_gradient. In this case the residuals are passed through a python closure:
@jax.custom_gradient
def f(x, y):
    Sx, Cx = jnp.sin(x), jnp.cos(x)
    def f_vjp(cotangent):
        dfdx = (Cx**2 - Sx**2) * y**2
        dfdy = Cx * Sx * 2 * y
        return (cotangent * dfdx, cotangent * dfdy)
    return Sx * Cx * y**2, f_vjp

x = jnp.array([0., jnp.pi/2, jnp.pi])
y = jnp.array([1., 2., 3.])
check_grads(f, (x, y), order=1, modes=("rev",))
This interface is much easier to understand. However, it is important to know that it may be slightly less optimal in some cases:
- It might be necessary to keep more temporary variables alive
- Unlike the default interface it does not encourage thinking about residuals carefully
- In some scenarios, we might actually want a different way to evaluate the forward pass when needing gradients
Many of these drawbacks may be mitigated by jax.jit.
In practice I’d recommend starting with the jax.custom_gradient interface when implementing a VJP. Once you have verified the implementation, you may switch to the jax.custom_vjp interface if you think there may be benefits from the more fine-grained control. This way you can first focus on the math and only later on the slightly awkward VJP interface.
Residuals & Checkpointing
When we chain several functions together and each stores a residual value, the memory usage can get quite high very quickly:
def get_used_bytes():
    return jax.devices()[0].memory_stats()["bytes_in_use"]

x = jnp.zeros(1024 * 256)  # ~ 1MB
bytes_start = get_used_bytes()
verbose = True

def print_bytes_diff():
    if verbose:
        bytes_diff = get_used_bytes() - bytes_start
        print(f"MB used: {bytes_diff / 1024 / 1024:.1f}")

def f(x):
    def f_impl(x):
        return x * x

    def f_fwd(x):
        print_bytes_diff()
        return x * x, x

    def f_vjp(x_res, cotangent):
        print_bytes_diff()
        return (cotangent * 2 * x_res,)

    f_impl = jax.custom_vjp(f_impl)
    f_impl.defvjp(f_fwd, f_vjp)
    return f_impl(x)

def g(x):
    x = f(x)
    x = f(x)
    x = f(x)
    x = f(x)
    x = f(x)
    return x

print("forward:")
y, fvjp = jax.vjp(g, x)
print("\nbackward:")
x_cotangent = fvjp(jnp.ones_like(x))
print("\nresiduals:")
verbose = False
jax.ad_checkpoint.print_saved_residuals(g, x)
del x, y, fvjp, x_cotangent  # to make results consistent when re-running
forward:
MB used: 0.0
MB used: 1.0
MB used: 2.0
MB used: 3.0
MB used: 4.0
backward:
MB used: 6.0
MB used: 7.0
MB used: 7.0
MB used: 7.0
MB used: 7.0
residuals:
f32[262144] from the argument x
f32[262144] output of mul from /tmp/ipykernel_840435/435038629.py:20:15 (f.<locals>.f_fwd)
f32[262144] output of mul from /tmp/ipykernel_840435/435038629.py:20:15 (f.<locals>.f_fwd)
f32[262144] output of mul from /tmp/ipykernel_840435/435038629.py:20:15 (f.<locals>.f_fwd)
f32[262144] output of mul from /tmp/ipykernel_840435/435038629.py:20:15 (f.<locals>.f_fwd)
The memory usage increases step by step during the forward pass, since more and more residuals get accumulated. As mentioned earlier, we can see all of the residuals by inspecting them with jax.ad_checkpoint.print_saved_residuals.
Using jax.jit can help eliminate many residuals. We will discuss this later.
It is possible to avoid storing residuals and instead instruct jax to recalculate them during the backward pass by using jax.checkpoint:
x = jnp.zeros(1024 * 256)  # ~ 1MB
bytes_start = get_used_bytes()
verbose = True
print("forward:")
y, fvjp = jax.vjp(jax.checkpoint(g), x)
print("\nbackward:")
x_cotangent = fvjp(jnp.ones_like(x))
print("\nresiduals:")
verbose = False
jax.ad_checkpoint.print_saved_residuals(jax.checkpoint(g), x)
del x, y, fvjp, x_cotangent  # to make results consistent when re-running
forward:
MB used: 0.0
MB used: 0.0
MB used: 0.0
MB used: 0.0
MB used: 0.0
backward:
MB used: 2.0
MB used: 2.0
MB used: 2.0
MB used: 2.0
MB used: 2.0
residuals:
f32[262144] from the argument x
Checkpointing is a quite sophisticated feature with many options. Consider having a look at the documentation of jax.checkpoint and this in-depth-guide.
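As one example of those options, jax.checkpoint accepts a policy that controls which intermediates are saved; jax.checkpoint_policies.checkpoint_dots saves matmul results as residuals and recomputes everything else. (The function g here is a toy example of mine, not from the original notebook.)

```python
import jax
import jax.numpy as jnp

def g(x):
    W = jnp.eye(x.shape[0])
    for _ in range(3):
        x = jnp.sin(W @ x)  # alternating matmuls and elementwise ops
    return jnp.sum(x)

x = jnp.ones(16)
# save only the dot/matmul outputs as residuals, recompute the rest
g_ckpt = jax.checkpoint(g, policy=jax.checkpoint_policies.checkpoint_dots)

val, grad = jax.value_and_grad(g_ckpt)(x)
val_ref, grad_ref = jax.value_and_grad(g)(x)
assert jnp.allclose(val, val_ref) and jnp.allclose(grad, grad_ref)
```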
Note that forward-mode differentiation trivially avoids this growth in memory usage, since no residuals are needed:
def f(x):
    def f_impl(x):
        print_bytes_diff()
        return x*x
    def f_jvp(primals, tangents):
        x, tangent = primals[0], tangents[0]
        print_bytes_diff()
        return x*x, tangent*2*x
    f_impl = jax.custom_jvp(f_impl)
    f_impl.defjvp(f_jvp)
    return f_impl(x)

def g(x):
    x = f(x)
    x = f(x)
    x = f(x)
    x = f(x)
    x = f(x)
    return x
x = jnp.zeros(1024*256) # ~ 1MB
bytes_start = get_used_bytes()
verbose = True
y, fjvp = jax.jvp(g, (x,), (jnp.ones_like(x),))
MB used: 1.0
MB used: 3.0
MB used: 3.0
MB used: 3.0
MB used: 3.0
Gradients and jax.jit
jax.jit can be quite good at eliminating residuals and fusing gradient computation and function evaluation into very few kernels. Consider e.g. our example from before:
def f(x):
    return x*x

def g(x):
    return jnp.sum(f(f(f(f(f(x))))))
x = jnp.zeros(1024*256) # ~ 1MB
show_hlo_info(jax.jit(jax.value_and_grad(g)), x)
-------- Memory usage of g ---------
const : 8 B
code : 9.7 kB
temp : 2.0 kB
arg : 1.0 MB
output: 1.0 MB
alias : 0 B
peak : 2.0 MB
However, things get worse for slightly more expensive functions (note the temporary memory below). In this case jax.checkpoint may help:
def f(x):
    return jnp.sin(jnp.cos(x*x))

def g(x):
    return jnp.sum(f(f(f(f(f(x))))))
x = jnp.zeros(1024*256) # ~ 1MB
show_hlo_info(jax.jit(jax.value_and_grad(g)), x, mode="mem")
print("With checkpointing:")
show_hlo_info(jax.jit(jax.value_and_grad(jax.checkpoint(g))), x, mode="mem")
-------- Memory usage of g ---------
const : 8 B
code : 114.2 kB
temp : 8.0 MB
arg : 1.0 MB
output: 1.0 MB
alias : 0 B
peak : 10.0 MB
With checkpointing:
-------- Memory usage of g ---------
const : 8 B
code : 130.9 kB
temp : 16 B
arg : 1.0 MB
output: 1.0 MB
alias : 0 B
peak : 2.0 MB
Loops
Jax’s backward autodiff tends to have horrendous memory requirements when differentiating loops – almost always storing all intermediate values.
def f(i, x):
    return x*x + 1

def g(x):
    xfin = jax.lax.fori_loop(0, 20, f, x)
    return jnp.sum(xfin)

def g_chunked(x):
    def g5loops(i, x):
        xfin = jax.lax.fori_loop(i, i+5, f, x)
        return xfin
    g5loops = jax.checkpoint(g5loops, static_argnums=0)
    for i in range(4):
        x = g5loops(i*5, x)
    return jnp.sum(x)
x = jnp.ones((1024,256)) # ~ 1MB
show_hlo_info(jax.jit(jax.grad(g)), x, mode="mem") # Memory usage grows with nloops
show_hlo_info(jax.jit(jax.grad(jax.checkpoint(g))), x, mode="mem") # Checkpointing doesn't help here (why?)
show_hlo_info(jax.jit(jax.grad(g_chunked)), x, mode="mem") # Chunking also doesn't help (why?) ## I guess jit messed it up?
-------- Memory usage of g ---------
const : 12 B
code : 12.0 kB
temp : 20.0 MB
arg : 1.0 MB
output: 1.0 MB
alias : 0 B
peak : 22.0 MB
-------- Memory usage of g ---------
const : 12 B
code : 12.5 kB
temp : 20.0 MB
arg : 1.0 MB
output: 1.0 MB
alias : 0 B
peak : 22.0 MB
-------- Memory usage of g_chunked ---------
const : 48 B
code : 12.2 kB
temp : 24.0 MB
arg : 1.0 MB
output: 1.0 MB
alias : 0 B
peak : 26.0 MB
I didn’t find a way to use checkpointing to avoid this. Please let me know if you know how to do this!
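One pattern worth trying here, adapted from the jax.checkpoint documentation, is recursive ("binomial") checkpointing over an explicitly unrolled chain of steps, so that only O(log n) residual scopes are live at once (at the cost of recomputation). A sketch, under the assumption that the step count is static; the step function below is illustrative, not the x*x + 1 body from above:

```python
import jax
import jax.numpy as jnp

def step(x):
    # illustrative, well-behaved step function
    return jnp.sin(x) + 0.05

def recursive_checkpoint(funs):
    """Compose funs left-to-right; checkpoint halves recursively so that only
    O(log n) residual scopes are live at once during the backward pass."""
    if len(funs) == 1:
        return funs[0]
    if len(funs) == 2:
        f1, f2 = funs
        return lambda x: f2(f1(x))
    mid = len(funs) // 2
    first = recursive_checkpoint(funs[:mid])
    second = recursive_checkpoint(funs[mid:])
    return lambda x: second(jax.checkpoint(first)(x))

chain = recursive_checkpoint([step] * 20)
x = jnp.ones((1024, 256)) * 0.01
grad = jax.grad(lambda x: jnp.sum(chain(x)))(x)
```

Whether this beats the plain loop in practice depends on the step cost and on how much jit is able to fuse afterwards.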
On the other hand, JVPs handle loops fine, since intermediate data can be consumed immediately:
(This doesn't mean, however, that we can replace the gradient by a forward Jacobian in this case.)
def jvp_g(x):
    return jax.jvp(g, (x,), (jnp.ones_like(x),))
show_hlo_info(jax.jit(jvp_g), x, mode="mem")
show_hlo_info(jax.jit(jax.jacfwd(g)), x, mode="mem") # explodes badly
-------- Memory usage of jvp_g ---------
const : 16 B
code : 13.0 kB
temp : 2.0 MB
arg : 1.0 MB
output: 24 B
alias : 0 B
peak : 3.0 MB
-------- Memory usage of g ---------
const : 8 B
code : 15.2 kB
temp : 256.5 GB
arg : 1.0 MB
output: 1.0 MB
alias : 0 B
peak : 256.5 GB
W1208 17:13:29.916063 840435 hlo_rematerialization.cc:3204] Can't reduce memory use below 5.70GiB (6116022681 bytes) by rematerialization; only reduced to 256.50GiB (275415826432 bytes), down from 256.50GiB (275415826432 bytes) originally
For the scenario at hand, the loop is reversible. Therefore, we can solve the problem by defining a custom_vjp that reconstructs the intermediate values one-by-one, starting from the endpoint:
def reversible_loop_with_vjp(body_fwd, body_bwd, init_val, nsteps):
    """Similar to jax.lax.fori_loop, but using a reversible approach in the VJP to save memory.

    body_fwd: the normal for loop body
    body_bwd: loop body that reverses body_fwd
    """
    @jax.custom_vjp
    def f(init_val):
        return jax.lax.fori_loop(0, nsteps, body_fwd, init_val)
    def f_fwd(init_val):
        final_val = f(init_val)
        return final_val, final_val
    def f_bwd(final_val, cotangent):
        def bwd_step(i, carry):
            val, cot = carry
            prev_val = body_bwd(nsteps-i-1, val)
            prev_cot = jax.vjp(body_fwd, nsteps-i-1, prev_val)[1](cot)[1]
            return prev_val, prev_cot
        initial_val, initial_cot = jax.lax.fori_loop(0, nsteps, bwd_step, (final_val, cotangent))
        return initial_cot,
    f.defvjp(f_fwd, f_bwd)
    return f(init_val)

def step_fwd(i, x):
    return x*x + 1e-5

def step_bwd(i, x):
    return jnp.sqrt(x - 1e-5)

def g(x):
    xfin = reversible_loop_with_vjp(step_fwd, step_bwd, x, 5)
    return jnp.mean(xfin)

x = jnp.ones((1024,256))*0.1 # ~ 1MB
jax.test_util.check_grads(g, (x,), order=1, modes=("rev",))
show_hlo_info(jax.jit(jax.grad(g)), x, mode="mem")
-------- Memory usage of g ---------
const : 16 B
code : 10.5 kB
temp : 1.0 MB
arg : 1.0 MB
output: 1.0 MB
alias : 0 B
peak : 3.0 MB
jax.linear_transpose
If you ever get confused while writing a VJP about what the adjoint of some operation should be, jax.linear_transpose is useful to know: you can experiment with it to find the adjoint of many linear operations:
print("adjoint of sum (= outer product with ones)")
def f(x):
    return jnp.sum(x)
x = jnp.linspace(0., 1., 10)
fT = jax.linear_transpose(f, x)
print(fT(2.)[0])

print("adjoint of gather (= scatter)")
isort = jnp.argsort(x)
def f(x):
    return x[isort]
x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
fT = jax.linear_transpose(f, x)
xsort = f(x)
print(fT(xsort)[0])
print(jnp.zeros_like(x).at[isort].add(xsort)) # gives the same
adjoint of sum (= outer product with ones)
[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
adjoint of gather (= scatter)
[0.947667 0.9785799 0.33229148 0.46866846 0.5698887 0.16550303
0.3101946 0.68948054 0.74676657 0.17101455]
[0.947667 0.9785799 0.33229148 0.46866846 0.5698887 0.16550303
0.3101946 0.68948054 0.74676657 0.17101455]
Another use case is inside a custom_vjp, where jax.linear_transpose can transpose linear sub-computations for you.
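As a hypothetical sketch of that idea, a custom_vjp around a linear map can delegate its backward rule to jax.linear_transpose rather than hand-deriving the adjoint (apply_A and f below are made up for illustration):

```python
import jax
import jax.numpy as jnp

A = jax.random.normal(jax.random.PRNGKey(0), (5, 3))

def apply_A(x):
    return A @ x  # a linear sub-computation

@jax.custom_vjp
def f(x):
    return apply_A(x)

def f_fwd(x):
    return f(x), None  # linear map: no residuals needed

def f_bwd(_, cotangent):
    # derive the adjoint automatically; the zeros only provide input shape/dtype
    fT = jax.linear_transpose(apply_A, jnp.zeros(3))
    return fT(cotangent)  # already a 1-tuple, matching f's single argument

f.defvjp(f_fwd, f_bwd)

cot = jnp.ones(5)
out = jax.vjp(f, jnp.ones(3))[1](cot)[0]  # equals A.T @ cot
```

This is mostly useful when the hand-written adjoint of a linear sub-computation would be error-prone.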
Some applications
Direct force summation
We want to compute gravitational forces through direct summation. For this, the function below processes particles in chunks to avoid creating an $N^2$ matrix in memory:
def direct_force_scan(x, m=1., n2lim=1e8, eps=1e-2):
    """Implements a direct N-body force calculation using jax.lax.scan to limit memory usage."""
    N = x.shape[0]
    nmax = int(np.ceil(n2lim / len(x)))
    nev = int(np.ceil(x.shape[0] / nmax))
    def force_segment(_, i):
        isel = jnp.arange(nmax, dtype=jnp.int32) + i*nmax
        dx = x - x[isel,None]
        r_ij2 = jnp.sum(dx**2, axis=-1)
        rinv = 1./jnp.sqrt(r_ij2 + eps**2)
        return None, -jnp.sum(dx * rinv[...,None]**3 * m, axis=1)
    _, forces = jax.lax.scan(force_segment, None, jnp.arange(nev, dtype=jnp.int32))
    return jnp.concatenate(forces)[0:N]
xpart = jax.random.normal(jax.random.PRNGKey(0), (1024*32, 3))
fpart = direct_force_scan(xpart)
print(fpart.shape)
(32768, 3)
While autodiff handles the JVP of this function well, the VJP fails miserably:
# JVP works fine:
val, tang = jax.jvp(direct_force_scan, (xpart,), (jnp.ones_like(xpart),))

def fjvp(x):
    return jax.jvp(direct_force_scan, (x,), (jnp.ones_like(x),))

def fgrad(x):
    return jax.grad(lambda x: jnp.sum(direct_force_scan(x)))(x)

show_hlo_info(jax.jit(fjvp), xpart, mode="mem") # JVP works fine
show_hlo_info(jax.jit(fgrad), xpart, mode="mem") # VJP consumes 53GB of memory!
-------- Memory usage of fjvp ---------
const : 64 B
code : 25.9 kB
temp : 772.8 MB
arg : 384.0 kB
output: 768.0 kB
alias : 0 B
peak : 773.9 MB
W1208 17:13:32.239346 840435 hlo_rematerialization.cc:3204] Can't reduce memory use below 5.70GiB (6116612505 bytes) by rematerialization; only reduced to 53.28GiB (57205066996 bytes), down from 53.28GiB (57205066996 bytes) originally
-------- Memory usage of fgrad ---------
const : 100 B
code : 68.8 kB
temp : 53.3 GB
arg : 384.0 kB
output: 384.0 kB
alias : 0 B
peak : 53.3 GB
Since the VJP fails so badly, we should consider implementing a custom_vjp. We could do this by writing a custom loop that accumulates the cotangents chunk by chunk.
However, there is a much easier way out here. The Jacobian of the force is symmetric:
\(\begin{align} \nabla_{x_k} \vec{F}_i &= - \sum_{j \neq i} \nabla_{x_k} \nabla_{x_i} \phi(\vec{x}_i - \vec{x}_j) \nonumber \\ &= \sum_{j \neq i} \left( \delta_{ik} - \delta_{jk} \right) {T}_{ij} \nonumber \\ &= \nabla_{x_i} \vec{F}_k \end{align}\) with the tidal tensor ${T}_{ij} = -\nabla \nabla \phi (\vec{x} = \vec{x}_i - \vec{x}_j)$ and $\phi$ the Green’s function of the gravitational potential. Since ${T}_{ij} = {T}_{ji}$, the Jacobian is symmetric, so VJP and JVP are identical: \(\begin{align} \sum_i g_i \frac{\partial \vec{F}_i}{\partial \vec{x}_k} = \sum_i \frac{\partial \vec{F}_k}{\partial \vec{x}_i} g_i \end{align}\) We may therefore simply replace the VJP by an invocation of the JVP – which we know jax handles efficiently:
def direct_force_scan_with_vjp(x, m=1., n2lim=1e8, eps=1e-1):
    def f_impl(x):
        return direct_force_scan(x, m, n2lim=n2lim, eps=eps)
    @jax.custom_vjp
    def f(x):
        return f_impl(x)
    def ffwd(x):
        return f_impl(x), x
    def fbwd(x_res, cotangent):
        return jax.jvp(f_impl, (x_res,), (cotangent,))[1],
    f.defvjp(ffwd, fbwd)
    return f(x)

def loss(x):
    return jnp.sum(direct_force_scan_with_vjp(x)**2)

from jax.test_util import check_grads
check_grads(jax.jit(loss), (xpart,), order=1, modes=("rev",), eps=1e-2, rtol=1e-1)
show_hlo_info(jax.jit(jax.grad(loss)), xpart, mode="mem")
-------- Memory usage of loss ---------
const : 80 B
code : 35.2 kB
temp : 768.3 MB
arg : 384.0 kB
output: 384.0 kB
alias : 0 B
peak : 769.0 MB
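As a sanity check of the symmetry argument, one can compare VJP and JVP directly on a dense stand-in for direct_force_scan (small enough that the plain VJP is affordable); force_dense below is written for illustration and is not the chunked function from above:

```python
import jax
import jax.numpy as jnp

def force_dense(x, eps=1e-1):
    # dense O(N^2) stand-in for direct_force_scan, for small N only
    dx = x[:, None, :] - x[None, :, :]                # (N, N, 3)
    rinv3 = (jnp.sum(dx**2, axis=-1) + eps**2) ** -1.5
    return -jnp.sum(dx * rinv3[..., None], axis=1)

xs = jax.random.normal(jax.random.PRNGKey(2), (64, 3))
g = jax.random.normal(jax.random.PRNGKey(3), (64, 3))

vjp_val = jax.vjp(force_dense, xs)[1](g)[0]     # g^T J
jvp_val = jax.jvp(force_dense, (xs,), (g,))[1]  # J g
# difference is floating-point noise only, since J is symmetric
print(jnp.max(jnp.abs(vjp_val - jvp_val)))
```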
We conclude that it can be viable to replace VJPs by JVPs when we have the option – because JVPs tend to work much better by default!
N-body simulation
The Jacobian of the mapping from initial to final positions/velocities in an N-body simulation is symplectic. We can use this to relate the VJP to the JVP of the inverse map: a symplectic Jacobian $M$ satisfies $M^T \Omega M = \Omega$ for the symplectic form $\Omega$, hence $M^T = \Omega M^{-1} \Omega^{-1}$. All it requires in practice is rotating position and velocity cotangents at the beginning and the end of the backward pass.
def nbody_sim(pos, vel, dt, nsteps):
    def run_impl(pos, vel, dt):
        def kick_drift_kick_step(i, state):
            x, v = state
            v = v + direct_force_scan(x) * 0.5*dt
            x = x + v * dt
            v = v + direct_force_scan(x) * 0.5*dt
            return x, v
        return jax.lax.fori_loop(0, nsteps, kick_drift_kick_step, (pos, vel))
    @jax.custom_vjp
    def run(pos, vel):
        return run_impl(pos, vel, dt)
    def run_fwd(pos, vel):
        pos_vel = run_impl(pos, vel, dt)
        return pos_vel, pos_vel
    def run_bwd(pos_vel, cot_pos_vel):
        pos, vel = pos_vel
        cpos, cvel = cot_pos_vel
        pos_vel0, (cpos0, cvel0) = jax.jvp(run_impl, (pos, vel, -dt), (-cvel, cpos, 0.))
        return (cvel0, -cpos0)
    run.defvjp(run_fwd, run_bwd)
    return run(pos, vel)

nbody_sim.jit = jax.jit(nbody_sim)

xpart = jax.random.normal(jax.random.PRNGKey(0), (1024*32, 3))
vpart = jax.random.normal(jax.random.PRNGKey(1), (1024*32, 3))
xfinal, vfinal = nbody_sim.jit(xpart, vpart, dt=1e-3, nsteps=10)

def loss(x, v):
    xfinal, vfinal = nbody_sim(x, v, dt=1e-3, nsteps=10)
    return jnp.mean(xfinal**2) + jnp.mean(vfinal**2)

grad = jax.jit(jax.grad(loss))(xpart, vpart)
check_grads(jax.jit(loss), (xpart, vpart), order=1, modes=("rev",), eps=1e-3, rtol=1e-1)
show_hlo_info(jax.jit(jax.grad(loss)), xpart, vpart, mode="mem")
-------- Memory usage of loss ---------
const : 236 B
code : 53.1 kB
temp : 774.7 MB
arg : 768.0 kB
output: 384.0 kB
alias : 0 B
peak : 775.8 MB
Note that we didn’t even need the VJP of the force here, since the backward pass only ever invokes a JVP!
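To see the cotangent rotation in isolation, here is a hypothetical toy check: a leapfrog integrator for a harmonic oscillator, where the ordinary autodiff VJP is compared against the rotated JVP of the time-reversed run (the kick-drift-kick step with -dt is the exact inverse map, and $M^T = \Omega M^{-1} \Omega^{-1}$ for the symplectic Jacobian $M$):

```python
import jax
import jax.numpy as jnp

nsteps = 50

def run(x, v, dt):
    # kick-drift-kick leapfrog for a harmonic oscillator, F(x) = -x
    def kdk(i, s):
        x, v = s
        v = v - x * 0.5 * dt
        x = x + v * dt
        v = v - x * 0.5 * dt
        return x, v
    return jax.lax.fori_loop(0, nsteps, kdk, (x, v))

x0 = jnp.array([1.0, 0.5])
v0 = jnp.array([0.0, -0.2])
dt = 0.01
xf, vf = run(x0, v0, dt)
cx = jnp.array([0.3, -1.0])  # cotangent of the output positions
cv = jnp.array([0.7, 0.2])   # cotangent of the output velocities

# reference: ordinary reverse-mode VJP
ref_cx, ref_cv = jax.vjp(lambda x, v: run(x, v, dt), x0, v0)[1]((cx, cv))

# symplectic trick: rotate cotangents into tangents, push them through the
# JVP of the time-reversed integration, and rotate back
_, (a, b) = jax.jvp(lambda x, v: run(x, v, -dt), (xf, vf), (-cv, cx))
trick_cx, trick_cv = b, -a
```

Up to floating-point error, (trick_cx, trick_cv) reproduces (ref_cx, ref_cv), mirroring the rotation used in run_bwd above.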