Jax's auto-diff

How derivatives are traced in JAX

Derivative tracing in jax

This guide provides a low-level look into how derivatives are traced in jax. In particular, we will cover:

  • What Jacobian-vector-products (JVP) and vector-Jacobian-products (VJP) are and why they form the basis of jax’s automatic differentiation pipeline
  • How auto-diff tracing works and what distinguishes it from jax’s jit-tracing
  • What residuals are and how they sometimes blow up memory usage during back-propagation
  • How and why to implement custom JVP and VJP rules
  • How and why VJP rules for problems with certain symmetries may be replaced by simpler JVPs

Prerequisites

I assume that you have already used jax for a bit and maybe also tried out jax.jit and jax.grad here and there. If you haven’t, it may be a good idea to play around with jax first and come back when you want to understand in more detail how auto-diff works.

You can find the notebook that this guide is based on here.

Forward-mode differentiation

Consider a nested function of the form

\[x_n = f_{n}(... f_3(f_2(f_1(x_0))))\]

We might write its evaluation as a series of evaluation steps

\[x_0 \xrightarrow{f_1} x_1 \xrightarrow{f_2} x_2 \xrightarrow{f_3} ... \xrightarrow{f_{n}} x_n\]

We are interested in taking derivatives of the final values with respect to the initial values. For now, let us assume that (the vector) $x_0$ may be written as a function of a single scalar value $\alpha$.

In that case we can write the final derivative in terms of the initial one:

\[\begin{align} \frac{\partial x_n}{\partial \alpha} &= ? \\ \dot{x}_0 &:= \frac{\partial x_0}{\partial \alpha} \\ \dot{x}_1 &:= \frac{\partial x_1}{\partial \alpha} = \frac{\partial x_1}{\partial x_0} \frac{\partial x_0}{\partial \alpha} = \frac{\partial x_1}{\partial x_0} \dot{x}_0 \\ &\;\vdots \\ \dot{x}_n &:= \frac{\partial x_n}{\partial \alpha} = \frac{\partial x_n}{\partial x_{n-1}} \dot{x}_{n-1} \end{align}\]

Here we multiply the Jacobian $J_n(x_{n-1}) = \frac{\partial x_n}{\partial x_{n-1}}$ with a tangent vector $\dot{x}_{n-1}$, which has the same shape as $x_{n-1}$. We call these evaluation steps Jacobian-vector-products (JVP). We may visualize the flow of this computation as follows:

\[\begin{array}{r l c l c l c l c} \alpha \rightarrow x_0 &\xrightarrow{f_1} & x_1 &\xrightarrow{f_2} & x_2 &\xrightarrow{f_3} &\cdots &\xrightarrow{f_n} & x_n \\ \searrow \quad &\searrow &&\searrow &&\searrow &&\searrow & \\ \dot{x}_0 &\xrightarrow{\text{JVP1}} & \dot{x}_1 &\xrightarrow{\text{JVP2}} & \dot{x}_2 &\xrightarrow{\text{JVP3}} & \cdots &\xrightarrow{\text{JVP}n} & \dot{x}_n \end{array}\]

In the initial step, $\alpha$ is used to define both $x_0$ and a specific tangent vector $\dot{x}_0$. Both the values $x_i$ (also referred to as primals) and the tangents $\dot{x}_i$ are then propagated step by step in what is called forward-mode differentiation. Note that the Jacobian in each step may explicitly depend on the input primals, as indicated by the arrows between the two rows of the graph.
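To make this concrete, here is a minimal sketch (my own illustration, not taken from the original notebook) that pushes a primal and a tangent through three arbitrarily chosen steps $f_1 = \sin$, $f_2(v) = v^2$, $f_3 = \exp$, with one jax.jvp call per step, and checks that this matches a single jax.jvp through the whole composition. The names steps and composed are hypothetical.

```python
import jax
import jax.numpy as jnp

# Three example steps f_1, f_2, f_3 (chosen arbitrarily for this sketch)
steps = [jnp.sin, lambda v: v ** 2, jnp.exp]

def composed(x0):
    x = x0
    for f in steps:
        x = f(x)
    return x

x0 = jnp.array([0.1, 0.5, 1.0])
x0_dot = jnp.array([1.0, 0.0, 2.0])  # tangent vector dx_0/dalpha

# Forward mode: propagate primal and tangent together, one JVP per step
x, x_dot = x0, x0_dot
for f in steps:
    x, x_dot = jax.jvp(f, (x,), (x_dot,))

# A single JVP through the whole composition gives the same primal and tangent
xn, xn_dot = jax.jvp(composed, (x0,), (x0_dot,))
print(jnp.allclose(x, xn), jnp.allclose(x_dot, xn_dot))
```

Each loop iteration is exactly one "JVP$i$" arrow of the diagram above: it needs the current primal (for the Jacobian) and the current tangent.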

Backward-mode differentiation

Let us consider a situation where we have a scalar function at the end of our computational graph and we want to know its derivative with respect to our input parameters. This is a common situation in machine learning and other optimization scenarios. E.g. for a final loss function $L$ we may write

\[\begin{align} \bar{x}_n &:= \frac{\partial L}{\partial x_n} \\ \bar{x}_{n-1} &:= \frac{\partial L}{\partial x_{n-1}} = \bar{x}_n \frac{\partial x_n}{\partial x_{n-1}} \\ &... \\ \bar{x}_0 &:= \frac{\partial L}{\partial x_{0}} = \bar{x}_1 \frac{\partial x_1}{\partial x_{0}} \end{align}\]

We call $\bar{x}_n$ the cotangent of $x_n$ and it has again the same shape as $x_n$. Since we multiply with our Jacobian from the left side, we speak of a vector-Jacobian-product (VJP). JVP and VJP are simply related through a transpose of the Jacobian (and perform the same operation when the Jacobian is symmetric).
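A quick numerical check of this relationship, as a sketch with an arbitrarily chosen function f: we build the explicit Jacobian with jax.jacfwd and compare $J\dot{x}$ with jax.jvp and $\bar{y}^T J$ with jax.vjp. The tangent lives in input space and the cotangent in output space, so for a non-square Jacobian they even have different shapes.

```python
import jax
import jax.numpy as jnp

def f(x):
    # small map R^3 -> R^2 with a non-square, non-symmetric Jacobian
    return jnp.stack([x[0] * x[1], jnp.sin(x[2]) + x[0]])

x = jnp.array([1.0, 2.0, 0.5])
J = jax.jacfwd(f)(x)  # explicit Jacobian, shape (2, 3)

v = jnp.array([1.0, -1.0, 0.5])  # tangent: input space, shape (3,)
_, jvp_out = jax.jvp(f, (x,), (v,))

u = jnp.array([0.3, -2.0])  # cotangent: output space, shape (2,)
_, pullback = jax.vjp(f, x)
(vjp_out,) = pullback(u)

print(jnp.allclose(jvp_out, J @ v))  # JVP computes J v
print(jnp.allclose(vjp_out, u @ J))  # VJP computes u^T J = J^T u
```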

We may visualize the flow of the computation through the following graph:

\[\begin{array}{c l c l c l c l c l} x_0 &\xrightarrow{f_1} & x_1 &\xrightarrow{f_2} & x_2 &\xrightarrow{f_3} & \cdots &\xrightarrow{f_n} & x_n &\rightarrow L \\ &\searrow &&\searrow &&\searrow &&\searrow & & \swarrow \\ \bar{x}_0 &\xleftarrow{\text{VJP1}} & \bar{x}_1 &\xleftarrow{\text{VJP2}} & \bar{x}_2 &\xleftarrow{\text{VJP3}} & \cdots &\xleftarrow{\text{VJP}n} & \bar{x}_n & \end{array}\]

We first evaluate all the primals $x_i$ in a forward computation until we reach the final loss function $L$. With $L$ we can define the cotangent $\bar{x}_n$ and then propagate it backward step by step (a.k.a. back-propagation) until we obtain $\bar{x}_0$, the quantity we are interested in. An important detail is that each VJP may depend on the primals in some form. Therefore, to implement back-propagation it is necessary to save some residuals of the forward pass and to discard them only once the backward evaluation has reached the corresponding stage. (The residuals could simply be all the input parameters of each function, but they can also be less than that, and they may be customized in jax.) The management of residuals is the most common reason for implementing custom VJP operations.
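The bookkeeping of residuals can be illustrated without any jax machinery. The following sketch hand-writes a (forward, backward) rule pair for three example steps ($\sin$, squaring, $\exp$): each forward rule returns its output plus a residual, and the backward pass consumes the residuals in reverse order. The rules and the helper manual_grad are hypothetical illustrations, and the loss $L = \sum x_n$ is chosen only so that we can compare against jax.grad.

```python
import jax
import jax.numpy as jnp

# Hand-written rules for this sketch only:
# forward: x -> (output, residual); backward: (residual, cotangent) -> input cotangent
steps = [
    (lambda x: (jnp.sin(x), jnp.cos(x)), lambda res, ct: ct * res),      # sin: keep cos(x)
    (lambda x: (x ** 2, x),               lambda res, ct: ct * 2 * res),  # square: keep x
    (lambda x: (jnp.exp(x), jnp.exp(x)),  lambda res, ct: ct * res),      # exp: keep exp(x)
]

def manual_grad(x0):
    """Gradient of L = sum(x_n) via explicit residual bookkeeping."""
    residuals = []
    x = x0
    for fwd, _ in steps:  # forward pass: store one residual per step
        x, res = fwd(x)
        residuals.append(res)
    ct = jnp.ones_like(x)  # x_bar_n = dL/dx_n for L = sum(x_n)
    for (_, bwd), res in zip(reversed(steps), reversed(residuals)):
        ct = bwd(res, ct)  # backward pass: consume residuals in reverse order
    return ct

def loss(x0):
    return jnp.sum(jnp.exp(jnp.sin(x0) ** 2))

x0 = jnp.array([0.1, 0.5, 1.0])
print(jnp.allclose(manual_grad(x0), jax.grad(loss)(x0)))
```

All residuals are alive at the turning point between the two passes, which is exactly where memory usage peaks.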

Clearly, backward differentiation is generally more complex than forward differentiation. However, in most practical applications it is the more relevant one, e.g. for calculating gradients of scalar loss functions: a single backward pass yields the derivatives with respect to all inputs, whereas forward mode needs one pass per input direction.
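The following sketch illustrates this cost argument with an arbitrary scalar loss: one VJP with output cotangent 1.0 yields the full gradient, while forward mode needs one JVP per input direction (here five of them).

```python
import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(jnp.sin(x) * x)

x = jnp.linspace(0.0, 1.0, 5)

# Reverse mode: a single VJP with output cotangent 1.0 yields the whole gradient
_, pullback = jax.vjp(loss, x)
(grad_rev,) = pullback(jnp.ones(()))

# Forward mode: one JVP per unit input direction e_i, i.e. x.size passes
grad_fwd = jnp.stack([jax.jvp(loss, (x,), (e,))[1] for e in jnp.eye(x.size)])

print(jnp.allclose(grad_rev, grad_fwd))
```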

A small recap on jax.jit tracing

Before we have a look at how autodiff tracing works in jax, let us briefly recap how jax.jit operates.

When jax.jit is called on a function and lowered, jax invokes the function with arguments replaced by tracers (but static arguments kept as they are).

Consider the example below where x becomes a tracer inside of the jitted context and all variables that depend on x also become tracers. Further, any jax functions that are called inside the jitted context create tracers (e.g. consider jnp.ones below). On the other hand the static variable “c” does not get replaced by a tracer.

The tracers are used to build a graph that represents the computation; this graph is later compiled by XLA into an efficient program for the target device (e.g. a GPU).

Conceptually, jit-tracers do not carry concrete values; they primarily carry shape and dtype information. Therefore, python control flow – which effectively acts as a pre-compiler – cannot depend on the values of the traced arrays.

import jax
import jax.numpy as jnp
from my_jax_utils import show_hlo_info

def f(x, c=3):
    print("x", x, type(x))
    print("c", c, type(c))
    b = jnp.ones((1024,1024))
    print("b", b, type(b))
    d = x + b
    print("d", d, type(d))
    e = jnp.sin(d)
    print("e", type(e))
    return e
f_jit = jax.jit(f, static_argnames=('c',))

x = jnp.ones((1024,1024))

print("Tracing...")
flower = f_jit.lower(x)
Output:
Tracing...
x JitTracer<float32[1024,1024]> <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
c 3 <class 'int'>
b JitTracer<float32[1024,1024]> <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
d JitTracer<float32[1024,1024]> <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
e <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
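Since the jit-tracers above carry only shapes and dtypes, a Python if on a traced value fails. The sketch below (function names are illustrative) shows the error and the usual workaround of expressing the choice as a traced operation with jnp.where. I catch jax.errors.ConcretizationTypeError, which as far as I know is the base class of the error raised when a tracer is converted to a bool.

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_branch(x):
    # Python `if` needs a concrete bool, but under jit x is an abstract tracer
    if x > 0:
        return x
    return jnp.zeros_like(x)

@jax.jit
def relu_where(x):
    # the value-dependent choice is expressed as a traced operation instead
    return jnp.where(x > 0, x, 0.0)

try:
    relu_branch(jnp.array(2.0))
    branch_failed = False
except jax.errors.ConcretizationTypeError:
    branch_failed = True

print(branch_failed)                # True
print(relu_where(jnp.array(-1.0)))  # 0.0
```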

Jit-tracers contain additional attributes (in particular .parent) that keep track of the computational graph. You may think of the tracers as directed edges in a computational graph. Their parent nodes correspond to computational primitives that may have several inputs and outputs. Below you can see how we can traverse the computational graph of the output variable by following the parent relationship of our output tracer:

def f(x, c=3):
    b = jnp.ones((1024,1024))
    d = x + b
    e = jnp.sin(d)

    print("e:", e.parent.primitive, "inputs:", e.parent.in_tracers)
    get_d = e.parent.in_tracers[0] # effectively gets us a reference of d
    print("d:", get_d.parent.primitive, "inputs:", get_d.parent.in_tracers)
    get_x = get_d.parent.in_tracers[0] # effectively gets us a reference of x
    print("x-parent:", get_x.parent)

    return e
f_jit = jax.jit(f, static_argnames=('c',))

x = jnp.ones((1024,1024))

flower = f_jit.lower(x)
Output:
e: sin inputs: [JitTracer<float32[1024,1024]>]
d: add inputs: [JitTracer<float32[1024,1024]>, JitTracer<float32[1024,1024]>]
x-parent: None

In this scenario, you can see that we effectively recover the same computational graph that we get after jit compiling the function. (In general, the jit compiled graph may look slightly different, since additional optimizations are performed on the graph.)

show_hlo_info(f_jit, x)
Output:
--------  Memory usage of f  ---------
const : 4 B
code  : 10.9 kB
temp  : 0 B
arg   : 4.0 MB
output: 4.0 MB
alias : 0 B
peak  : 8.0 MB

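The .parent attributes used above are internal details of the tracer classes and may change between jax versions. A public way to look at the traced graph is jax.make_jaxpr. A minimal sketch with a smaller variant of f:

```python
import jax
import jax.numpy as jnp

def f(x):
    b = jnp.ones((4, 4))
    d = x + b
    return jnp.sin(d)

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((4, 4)))
print(closed_jaxpr)  # textual form of the traced graph

# the equations list the primitives in the order they are executed
prims = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(prims)
```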

Auto-diff tracing

When we use jax’s built-in automatic differentiation operators, jax also replaces some arrays by tracers. However, these tracers operate quite differently from jit-tracers and serve a completely different purpose. Auto-diff tracers are used to rearrange function calls and residuals so as to facilitate the forward- and backward-differentiation processes.

Most notably, auto-diff tracers carry actual values (not only dtypes/shapes) with them. Consider the following example:

def f(x):
    const = jnp.arange(len(x))
    out = x*const
    print("type(x): ", type(x), "value:", x.primal, x.tangent)
    print("type(const): ", type(const), const)
    print("type(out): ", type(out), "value:",  out.primal, out.tangent)
    if x[0] == 1.:
        print("Branch taken")

    return out

x = jnp.ones((4,))
xtangent = jnp.array([0.1, 0.2, 0.1, 0.2])

print("JVP:")
out_val, out_tangent = jax.jvp(f, (x,), (xtangent,))
print("out_val", out_val, "out_tangent", out_tangent)
Output:
JVP:
type(x):  <class 'jax._src.interpreters.ad.JVPTracer'> value: [1. 1. 1. 1.] [0.1 0.2 0.1 0.2]
type(const):  <class 'jaxlib._jax.ArrayImpl'> [0 1 2 3]
type(out):  <class 'jax._src.interpreters.ad.JVPTracer'> value: [0. 1. 2. 3.] [0.  0.2 0.2 0.6]
Branch taken
out_val [0. 1. 2. 3.] out_tangent [0.  0.2 0.2 0.6]

Here we have defined a function and evaluated its JVP. Recall that to evaluate the JVP we need the value of x (referred to as the “primal”) and the value of the input tangent vector $\dot{x}$ (the “tangent”).

Some aspects worth noting here:

  • x got replaced by an instance of jax._src.interpreters.ad.JVPTracer. For brevity, we will refer to tracers like this as AD-tracers (standing for “auto-diff”)
  • All values that are inferred from x also got replaced by AD-tracers
  • Values that do not depend on x (like const above) do not become AD-tracers. (Note that this is different from jax.jit, where the result of any computation becomes a tracer.) Their tangents are effectively zero, since they don’t depend on any traced inputs
  • Every AD-tracer contains both the primal and the tangent. For JVP tracers the tangents of outputs of any operation are evaluated immediately and stored in the new tracer
  • In principle it is possible to have value-dependent branches inside of a differentiated function. However, this is no longer possible when composing with jax.jit and should therefore be avoided
  • The result of the JVP call gives the output of the function and the tangents of the output
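The point about value-dependent branches can be checked directly. In this sketch (with an arbitrary function f), the branch works under jax.jvp because the primal is concrete, but the same call wrapped in jax.jit fails because the primal becomes an abstract jit-tracer:

```python
import jax
import jax.numpy as jnp

def f(x):
    # value-dependent Python branch on the primal
    if x[0] > 0:
        return jnp.sin(x)
    return jnp.cos(x)

x = jnp.array([0.5, 1.0])
t = jnp.ones_like(x)

# Works: JVP tracers carry concrete primal values, so the branch can be decided
out, out_t = jax.jvp(f, (x,), (t,))

# Fails: inside jit the primal itself is an abstract jit-tracer
try:
    jax.jit(lambda x, t: jax.jvp(f, (x,), (t,)))(x, t)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

print(out_t, jit_failed)
```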

VJP tracers behave in a slightly more complex manner. Consider the following call to the VJP of the function above:

print("VJP:")
out_val, cotangent_pullback = jax.vjp(f, x)
print("out_val", out_val)
print("cotangent_pullback", cotangent_pullback)
output_cotangent = jnp.array([0., 0.2, 0.2, 0.6])
print("x_cotangent:", cotangent_pullback(output_cotangent))
Output:
VJP:
type(x):  <class 'jax._src.interpreters.ad.LinearizeTracer'> value: [1. 1. 1. 1.] JitTracer<float32[4]>
type(const):  <class 'jaxlib._jax.ArrayImpl'> [0 1 2 3]
type(out):  <class 'jax._src.interpreters.ad.LinearizeTracer'> value: [0. 1. 2. 3.] JitTracer<float32[4]>
Branch taken
out_val [0. 1. 2. 3.]
cotangent_pullback Partial(_HashableCallableShim(functools.partial(<function _vjp_pullback_wrapper at 0x75f7070ebd80>, 'f', [ShapedArray(float32[4])], (PyTreeDef(*), PyTreeDef((*,))))), Partial(_HashableCallableShim(functools.partial(<function vjp.<locals>.unbound_vjp at 0x75f706516840>, [(ShapedArray(float32[4]), None)], { lambda a:f32[4]; b:f32[4]. let c:f32[4] = mul b a in (c,) })), [Array([0., 1., 2., 3.], dtype=float32)]))
x_cotangent: (Array([0.       , 0.2      , 0.4      , 1.8000001], dtype=float32),)

Note that:

  • The tracers are now of a different type: jax._src.interpreters.ad.LinearizeTracer
  • They still have a primal value at each intermediate step
  • Their tangents are replaced by jit-tracers that do not carry a value. The jit-tracers are used to store a (partial) graph for each operation. The full graph can only be assembled at the end of the function.
  • The output of the VJP gives us the primal output and a function (the pullback). To evaluate the cotangent of x, we need to call this function with the cotangent of the outputs
  • Although we called the pullback with a cotangent equal to the output tangent from the JVP, we do not get back the input tangent of the JVP. A JVP advects a tangent from input to output space, and a VJP advects a cotangent from output to input space. They don’t undo each other, because the VJP multiplies with the transposed Jacobian $J^T$, not with the inverse $J^{-1}$. Here $J = \mathrm{diag}(0, 1, 2, 3)$, so the cotangent gets multiplied elementwise by $(0, 1, 2, 3)$ once more instead of being divided by it.
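The relation between the two can be made explicit with the public jax.linearize and jax.linear_transpose APIs. The sketch below uses the same f as above: linearize returns the linear tangent map $\dot{x} \mapsto J\dot{x}$, and transposing it gives $\bar{y} \mapsto J^T\bar{y}$, which agrees with the pullback from jax.vjp. As far as I understand, this linearize-then-transpose structure is also what jax.vjp uses internally, but take this as an illustration rather than a description of jax's code.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x * jnp.arange(len(x))

x = jnp.ones((4,))

# jax.linearize: primal output plus the linear tangent map  x_dot -> J x_dot
y, f_lin = jax.linearize(f, x)

# jax.linear_transpose turns that linear map into  y_bar -> J^T y_bar
f_lin_T = jax.linear_transpose(f_lin, x)

y_bar = jnp.array([0.0, 0.2, 0.2, 0.6])
(x_bar,) = f_lin_T(y_bar)

_, pullback = jax.vjp(f, x)
(x_bar_vjp,) = pullback(y_bar)
print(x_bar, jnp.allclose(x_bar, x_bar_vjp))
```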

To appreciate a bit what the jit-tracer at .tangent does, consider the following:

def f(x):
    y = jnp.sin(x)
    print(y.tangent.parent.primitive, y.tangent.parent.in_tracers)
    ytang_in = y.tangent.parent.in_tracers
    print("inputs' parents", ytang_in[0].parent, ytang_in[1].parent)
    return y

val, cotangent_fun = jax.vjp(f, jnp.array([0., jnp.pi/3]))
Output:
mul [JitTracer<float32[2]>, JitTracer<float32[2]>]
inputs' parents None None

Note that the jit-tracer of y.tangent comes from a mul operation with two inputs that both have shape (2,). Note that for $y = \sin(x)$ the JVP and the VJP are

\[\dot{y} = \cos(x) \cdot \dot{x} = \text{residual} \cdot \dot{x}, \qquad \bar{x} = \cos(x) \cdot \bar{y} = \text{residual} \cdot \bar{y}\]

Therefore y.tangent is the result of the multiplication in the linear tangent map, with two inputs: the residual $\cos(x)$ and the input tangent $\dot{x}$. When we inspect the parents in the graph, we can see that both are direct inputs of the graph. This means that $\cos(x)$ is already precomputed in the forward pass and then reused later. For the backward pass, jax transposes this linear graph, which yields $\bar{x} = \cos(x) \cdot \bar{y}$ with the same residual.

If the concept of the residual confuses you: don’t worry, we will discuss it in more detail when writing custom VJPs later.

I couldn’t find a direct way to inspect the residuals during tracing / the forward pass. (It might be that they are kept in a python closure.) However, we can see the residuals with jax.ad_checkpoint.print_saved_residuals. You can see that it is indeed the result of a cos operation:

def f(x):
    y = jnp.sin(x)
    return y

import jax.ad_checkpoint
jax.ad_checkpoint.print_saved_residuals(f, jnp.array([0., jnp.pi/3]))
Output:
f32[2] output of cos from /tmp/ipykernel_840435/1615555903.py:2:8 (f)

jax.lax.stop_gradient

Sometimes we don’t want some intermediate result to be considered differentiable. For this you can simply use jax.lax.stop_gradient. You can see below that this converts an AD-tracer to its value, thus stopping tangent propagation for this value.

There was a bug in jax that made jax.lax.stop_gradient not work properly for integer valued arrays. It was fixed in this issue and you can find a forward/backward compatible version of stop_gradient in the issue’s comments.

def f(x):
    a = jnp.mean(x, axis=0)
    a2 = jax.lax.stop_gradient(a)
    print("type(x): ", type(x), "value:", x.primal, x.tangent)
    print("type(a): ", type(a), "value:", a.primal, a.tangent)
    print("type(a2): ", type(a2), "value:", a2)

    return x*a*a2

x = jnp.ones((4,))
xtangent = jnp.array([0.1, 0.2, 0.1, 0.2])

out_val, out_tangent = jax.jvp(f, (x,), (xtangent,))
Output:
type(x):  <class 'jax._src.interpreters.ad.JVPTracer'> value: [1. 1. 1. 1.] [0.1 0.2 0.1 0.2]
type(a):  <class 'jax._src.interpreters.ad.JVPTracer'> value: 1.0 0.15
type(a2):  <class 'jaxlib._jax.ArrayImpl'> value: 1.0
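To see the effect on an actual gradient, here is a small sketch with an arbitrary loss: once the mean passes through jax.lax.stop_gradient, it is treated as a constant and its own dependence on x no longer contributes.

```python
import jax
import jax.numpy as jnp

def loss(x):
    # sum(x) * mean(x) = sum(x)^2 / n, analytic gradient 2 * sum(x) / n
    return jnp.sum(x * jnp.mean(x))

def loss_sg(x):
    # the mean is treated as a constant, analytic gradient mean(x)
    return jnp.sum(x * jax.lax.stop_gradient(jnp.mean(x)))

x = jnp.array([1.0, 2.0, 3.0])
print(jax.grad(loss)(x))
print(jax.grad(loss_sg)(x))
```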

Wrappers around VJPs and JVPs

Jax provides a number of wrappers around JVPs and VJPs that offer a more convenient interface. In particular, jax.jacfwd and jax.grad are relevant. It is useful to understand how they operate, in case you ever need slightly modified versions of them.

As you can see below, jax.jacfwd can be reproduced with multiple JVPs on unit input tangents, vectorized with jax’s vmap functionality (which is essentially how jax implements it):

def f(x):
    print(type(x), x.primal, x.tangent)
    return jnp.sin(x)

def custom_jacfwd(f):
    """Emulates jax.jacfwd"""
    # to get the Jacobian we evaluate the JVP of f 
    # for input tangents [1,0,0,...], [0,1,0,...], ...

    def fjac(x):
        def fjvp(xtangent):
            return jax.jvp(f, (x,), (xtangent,))

        input_tangents = jnp.eye(len(x))

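        # out_axes=(None, -1): the primal output is the same for every tangent, and
        # stacking the output tangents along the last axis gives J[i, j] = df_i/dx_j
        _, out_tangents = jax.vmap(fjvp, out_axes=(None, -1))(input_tangents)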
        return out_tangents
    
    return fjac

x = jnp.linspace(0.,1., 3)
print("jax.jacfwd")
print(jax.jacfwd(f)(x))
print("My jacfwd")
print(custom_jacfwd(f)(x))
Output:
jax.jacfwd
<class 'jax._src.interpreters.ad.JVPTracer'> [0.  0.5 1. ] VmapTracer<float32[3]>
[[1.         0.         0.        ]
 [0.         0.87758255 0.        ]
 [0.         0.         0.5403023 ]]
My jacfwd
<class 'jax._src.interpreters.ad.JVPTracer'> [0.  0.5 1. ] VmapTracer<float32[3]>
[[1.         0.         0.        ]
 [0.         0.87758255 0.        ]
 [0.         0.         0.5403023 ]]

jax.grad is simply a wrapper around the VJP that propagates a unit output cotangent of a scalar function back to the inputs:

def loss(x):
    print("x:", type(x), x.primal, x.tangent)
    return jnp.sum(x * jnp.arange(len(x)))

def custom_grad(f):
    """Emulates jax.grad"""
    def grad_loss(x):
        vjp_func = jax.vjp(f, x)[1]  # use the function passed to custom_grad

        output_cotangent = jnp.array(1.)
        input_cotangent = vjp_func(output_cotangent)[0]
        return input_cotangent

    return grad_loss

print("jax.grad:")
print(jax.grad(loss)(jnp.array([1., 2., 3.])))

print("custom_grad:")
print(custom_grad(loss)(jnp.array([1., 2., 3.])))
Output:
jax.grad:
x: <class 'jax._src.interpreters.ad.LinearizeTracer'> [1. 2. 3.] JitTracer<float32[3]>
[0. 1. 2.]
custom_grad:
x: <class 'jax._src.interpreters.ad.LinearizeTracer'> [1. 2. 3.] JitTracer<float32[3]>
[0. 1. 2.]

It should be easy to imagine how other wrappers like jax.value_and_grad operate. Higher-order derivatives rely on nested tracing (where tracer values are tracers themselves); for example, jax.hessian is defined as jax.jacfwd(jax.jacrev(f)). However, we won’t go that deep here.
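A quick sketch (with an arbitrary separable test function) that compares jax.hessian with the explicit nesting and also uses jax.value_and_grad:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.array([0.3, -1.2, 2.0])

# reverse-mode Jacobian (the gradient) nested inside a forward-mode Jacobian
H = jax.hessian(f)(x)
H_nested = jax.jacfwd(jax.jacrev(f))(x)

# value_and_grad returns f(x) and its gradient from a single VJP
val, g = jax.value_and_grad(f)(x)

print(jnp.allclose(H, H_nested))
```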

Custom JVPs

Since most derivative operations can be composed efficiently from JVPs and VJPs, for the rest of this investigation we will focus on how these operate and how to customize and optimize them.

Since JVPs are fundamentally simpler, we will start with these.

@jax.custom_jvp
def f(x):
    print("Calling f: type(x): ", type(x))
    return jnp.sin(x)

def f_jvp(primals, tangents):
    print("Calling JVP: type(x): ", type(primals[0]))
    x, x_tangent = primals[0], tangents[0]
    return jnp.sin(x), x_tangent * jnp.cos(x)

f.defjvp(f_jvp)

x = jnp.array([0., jnp.pi/2, jnp.pi])
x_tangent = jnp.array([1., 1., 1.])

print("Evaluating JVP:")
out_val, out_tangent = jax.jvp(f, (x,), (x_tangent,))

print("\nEvaluating f:")
out = f(x)

def function_using_f(x):
    print("A function calling f: type(x):", type(x))
    y = f(x)
    print("type(y):", type(y))
    return y

print("\nEvaluating function_using_f with JVP:")
out_val, out_tangent = jax.jvp(function_using_f, (x,), (x_tangent,))
Output:
Evaluating JVP:
Calling JVP: type(x):  <class 'jaxlib._jax.ArrayImpl'>

Evaluating f:
Calling f: type(x):  <class 'jaxlib._jax.ArrayImpl'>

Evaluating function_using_f with JVP:
A function calling f: type(x): <class 'jax._src.interpreters.ad.JVPTracer'>
Calling JVP: type(x):  <class 'jaxlib._jax.ArrayImpl'>
type(y): <class 'jax._src.interpreters.ad.JVPTracer'>

A few points are worth noting here:

  • We had to define two functions for the JVP. One that gets evaluated when using f normally and one that is used to evaluate both f and its JVP at once.
  • When evaluating the JVP only the latter function gets called (and the former is not called at all). This interface makes sense, since there may be some parts of the evaluation of f and its JVP that can be reused, so it is often better to define a single function to do both.
  • In these calls, both functions receive concrete arrays (not AD-tracers) and return concrete arrays
  • Inside of “function_using_f” x is a tracer.
  • We infer that the custom_jvp decorator handles the interface. If it gets called with an AD-tracer, it calls f_jvp with the primals/tangents of the tracer and returns the outputs as AD-tracers. If it gets called with a normal array, it simply calls f.
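Besides reusing computations, a common reason for a custom JVP is numerical stability. The sketch below is adapted from the well-known log1pexp example in the JAX custom-derivatives tutorial and written in the nested style used later in this guide: for $x = 100$, exp(x) overflows in float32 and the automatically derived gradient becomes nan, while the hand-written derivative $1 - 1/(1+e^x)$ stays finite.

```python
import jax
import jax.numpy as jnp

def log1pexp_naive(x):
    return jnp.log(1.0 + jnp.exp(x))

def log1pexp(x):
    def f_impl(x):
        return jnp.log(1.0 + jnp.exp(x))

    def f_jvp(primals, tangents):
        x, x_tangent = primals[0], tangents[0]
        # d/dx log(1 + e^x) = 1 - 1/(1 + e^x): stays finite when e^x overflows
        return jnp.log(1.0 + jnp.exp(x)), x_tangent * (1.0 - 1.0 / (1.0 + jnp.exp(x)))

    f_impl = jax.custom_jvp(f_impl)
    f_impl.defjvp(f_jvp)
    return f_impl(x)

x = jnp.array(100.0)
print(jax.grad(log1pexp_naive)(x))  # nan: inf * 0 inside the chain rule
print(jax.grad(log1pexp)(x))
```

Since the JVP rule is linear in x_tangent, jax.grad can use it as well, as discussed further below.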

Jax bundles the differentiable inputs of the JVP rule (primals and tangents) into tuples and passes all non-differentiable inputs separately as leading arguments. The interface may be slightly awkward at times, especially with keyword arguments. Consider the following example:

from functools import partial

@partial(jax.custom_jvp, nondiff_argnames=('a', 'flag'))
def f(x, a=2, flag=True):
    if flag:
        return x**a
    else:
        return x**-a

def f_jvp(a, flag, primals, tangents):
    x = primals[0]
    x_tangent = tangents[0]
    if flag:
        primal_out = x**a
        tangent_out = a * x**(a-1) * x_tangent
    else:
        primal_out = x**-a
        tangent_out = -a * x**(-a-1) * x_tangent
    return primal_out, tangent_out

f.defjvp(f_jvp)

def g(x):
    return f(x, 3, flag=True) + f(x, 3, flag=False)

jax.jvp(f, (jnp.array([2.,3.,4.]), ), (jnp.array([1.,1.,1.]),))
Output:
(Array([ 4.,  9., 16.], dtype=float32), Array([4., 6., 8.], dtype=float32))

This example works, but it is very easy to mess this up. Note that all the keyword arguments have to become flat positional arguments of the JVP function. If you wanted something slightly different, e.g. a differentiable keyword argument or a non-differentiable positional argument… feel free to try it… and good luck with that!

To make things simpler, I found it useful to define a nested function with the inner function only containing differentiable arguments. For the example above, we could write the following:

def f(x, a=2, flag=True):
    def f_impl(x):
        if flag:
            return x**a
        else:
            return x**-a

    def f_jvp(primals, tangents):
        x = primals[0]
        x_tangent = tangents[0]
        if flag:
            primal_out = x**a
            tangent_out = a * x**(a-1) * x_tangent
        else:
            primal_out = x**-a
            tangent_out = -a * x**(-a-1) * x_tangent
        return primal_out, tangent_out
    
    f_impl = jax.custom_jvp(f_impl)
    f_impl.defjvp(f_jvp)
    
    return f_impl(x)

def g(x):
    return f(x, 3, flag=True) + f(x, 3, flag=False)

jax.jvp(f, (jnp.array([2.,3.,4.]), ), (jnp.array([1.,1.,1.]),))
Output:
(Array([ 4.,  9., 16.], dtype=float32), Array([4., 6., 8.], dtype=float32))

While this achieves the same thing, I find it much easier to read. Adding arguments is much less scary, and default values of optional arguments don’t need to be repeated several times. It also has the advantage that only f, but not f_jvp (which we don’t expect anyone to call), is visible at the module level, which makes it easier for users to find the right function. We’ll stick with this pattern for the rest of this tutorial.

As a note of caution: make sure never to take derivatives with respect to one of the outer, non-differentiable parameters (like a or flag above). If this happens, an AD-tracer leaks into the JVP rule through the closure, which causes a tracer error. Anything that may need a derivative should be an argument of the inner function. If you want to call the function with AD-tracers at non-differentiable input locations, explicitly remove their tangents with jax.lax.stop_gradient.

One very useful aspect of defining JVPs is that jax can automatically infer the VJP from them:

def f(x):
    def f_impl(x):
        print("Calling f")
        return jnp.sin(x)
    
    def f_jvp(primals, tangents):
        x, x_tangent = primals[0], tangents[0]
        print("Calling JVP", type(x), type(x_tangent))
        return jnp.sin(x), x_tangent * jnp.cos(x)
    
    f_impl = jax.custom_jvp(f_impl)
    f_impl.defjvp(f_jvp)
    return f_impl(x)

val, cotangent_fun =  jax.vjp(f, jnp.array([0., jnp.pi/2, jnp.pi]))

output_cotangent = jnp.array([1., 1., 1.])

print("Evaluate Tangent:")
input_cotangent = cotangent_fun(output_cotangent)
Output:
Calling JVP <class 'jaxlib._jax.ArrayImpl'> <class 'jax._src.interpreters.partial_eval.JaxprTracer'>
Evaluate Tangent:

Even though we defined only a JVP, jax is also able to use our function in backward differentiation. During the forward pass, jax has replaced the tangent by a tracer that captures the part of the JVP rule that is linear in the tangent, i.e. the graph needed for the backward pass. Jax then obtains the VJP by transposing this linear graph (compare jax.linearize, which extracts such a linear map on the tangent space, and jax.linear_transpose, which transposes it). Transposition rules are implemented for all linear primitives in jax. For this to work, the output tangent of the custom JVP rule has to be a linear function of the input tangents.

I am not entirely sure about this, but I suspect that implementing a custom_jvp and using jax’s VJP-through-JVP machinery may give close-to-optimal results for a VJP in many cases.

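The opposite, a JVP through a custom_vjp function, is not possible in jax: forward-mode differentiation of a custom_vjp function raises an error. Probably the reason is that jax would have to transpose the user-defined backward rule to recover a JVP, and it treats that rule as an opaque function that is only ever run in the backward pass.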

Custom VJPs

To define a custom_vjp we need to provide three different functions:

  • A function for normal evaluation
  • A function that is evaluated in the forward pass of the VJP, returning the function value and residuals for the backward pass
  • The VJP function that gets the residuals and output-space cotangents and returns the input-space cotangents
def f(x, y):
    def f_impl(x, y):
        print("Calling f")
        return jnp.sin(x) * jnp.cos(x) * y**2
    
    def f_fwd(x, y):
        print("Calling f_fwd")
        Sx, Cx = jnp.sin(x), jnp.cos(x)
        return Sx*Cx * y**2, (Sx, Cx, y)
    
    def f_vjp(residuals, cotangent):
        print("Calling f_vjp")
        Sx, Cx, y = residuals
        dfdx = (Cx**2 - Sx**2) * y**2
        dfdy = Cx * Sx * 2 * y
        return (cotangent * dfdx, cotangent * dfdy)
    
    f_impl = jax.custom_vjp(f_impl)
    f_impl.defvjp(f_fwd, f_vjp)
    return f_impl(x, y)

x = jnp.array([0., jnp.pi/2, jnp.pi])
y = jnp.array([1., 2., 3.])

print("normal eval")
val = f(x, y)
print("\nforward pass")
val, cotangent_fun =  jax.vjp(f, x, y)
print("\nbackward pass")
output_cotangent = jnp.array([1., 1., 1.])
input_cotangents = cotangent_fun(output_cotangent)
Output:
normal eval
Calling f

forward pass
Calling f_fwd

backward pass
Calling f_vjp

Note that in this example we save the values of sin(x), cos(x) and y as residuals to be reused in the backward pass so that we can avoid re-evaluating these later.

In practice, it would probably be better to pass only x as a residual here and to recompute sin(x) and cos(x) in the backward pass. Storing data and reading it back from memory is often more expensive on a GPU than recomputing cheap elementwise functions. However, for the sake of this example, imagine that sin and cos represent some very expensive calculations that we really don’t want to redo.

You can see that, compared to JVPs, the main conceptual difference is that our forward function may return residuals that are kept around for the backward pass. Since residuals can weigh heavily on memory usage, they are often one of the most important things to optimize / minimize when designing good custom_vjp functions.

It is a good idea to verify that our VJP implementation is correct. This is easily done with jax.test_util.check_grads, which compares the derivatives against finite differences:

from jax.test_util import check_grads
check_grads(f, (x, y), order=1, modes=("rev",))
Output:
Calling f_fwd
Calling f
Calling f
Calling f
Calling f_vjp
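For comparison, here is a sketch of the variant suggested above that stores only x and y as residuals and recomputes sin(x) and cos(x) in the backward pass. It follows the same nested pattern as f; print_saved_residuals can be used to inspect which residuals are now kept.

```python
import jax
import jax.numpy as jnp
import jax.ad_checkpoint
from jax.test_util import check_grads

def f(x, y):
    def f_impl(x, y):
        return jnp.sin(x) * jnp.cos(x) * y**2

    def f_fwd(x, y):
        # keep only the inputs as residuals
        return jnp.sin(x) * jnp.cos(x) * y**2, (x, y)

    def f_vjp(residuals, cotangent):
        x, y = residuals
        # recompute sin/cos here instead of storing them in the forward pass
        Sx, Cx = jnp.sin(x), jnp.cos(x)
        dfdx = (Cx**2 - Sx**2) * y**2
        dfdy = Cx * Sx * 2 * y
        return (cotangent * dfdx, cotangent * dfdy)

    f_impl = jax.custom_vjp(f_impl)
    f_impl.defvjp(f_fwd, f_vjp)
    return f_impl(x, y)

x = jnp.array([0., jnp.pi / 2, jnp.pi])
y = jnp.array([1., 2., 3.])
check_grads(f, (x, y), order=1, modes=("rev",))
jax.ad_checkpoint.print_saved_residuals(f, x, y)
```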

It is worth knowing that there is also an alternative interface for custom VJPs via jax.custom_gradient. In this case the residuals are passed through a python closure:

@jax.custom_gradient
def f(x, y):
    Sx, Cx = jnp.sin(x), jnp.cos(x)
    
    def f_vjp(cotangent):
        dfdx = (Cx**2 - Sx**2) * y**2
        dfdy = Cx * Sx * 2 * y
        return (cotangent * dfdx, cotangent * dfdy)

    return Sx*Cx * y**2, f_vjp

x = jnp.array([0., jnp.pi/2, jnp.pi])
y = jnp.array([1., 2., 3.])

check_grads(f, (x, y), order=1, modes=("rev",))

This interface is much easier to understand. However, it may be slightly less efficient in some cases:

  • It might be necessary to keep more temporary variables alive
  • Unlike the default interface it does not encourage thinking about residuals carefully
  • In some scenarios, we might actually want a different way to evaluate the forward pass when needing gradients

Many of these drawbacks may be mitigated by jax.jit.

In practice, I’d recommend starting with the jax.custom_gradient interface when implementing a VJP. Once you have verified the implementation, you may switch to the jax.custom_vjp interface if you think there may be benefits from its more fine-grained control. This way you can first focus on the math and only later on the slightly awkward VJP interface.

Residuals & Checkpointing

When we chain several functions together and each stores a residual value, the memory usage can get quite high very quickly:

def get_used_bytes():
    return jax.devices()[0].memory_stats()["bytes_in_use"]

x = jnp.zeros(1024*256)  # ~ 1MB

bytes_start = get_used_bytes()

verbose = True
def print_bytes_diff():
    if verbose:
        bytes_diff = get_used_bytes() - bytes_start
        print(f"MB used: {bytes_diff / 1024/1024:.1f}")

def f(x):
    def f_impl(x):
        return x*x
    
    def f_fwd(x):
        print_bytes_diff()
        return x*x, x
    
    def f_vjp(x_res, cotangent):
        print_bytes_diff()
        return (cotangent * 2*x_res,)
    
    f_impl = jax.custom_vjp(f_impl)
    f_impl.defvjp(f_fwd, f_vjp)

    return f_impl(x)


def g(x):
    x = f(x)
    x = f(x)
    x = f(x)
    x = f(x)
    x = f(x)
    return x

print("forward:")
y, fvjp = jax.vjp(g, x)

print("\nbackward:")
x_cotangent = fvjp(jnp.ones_like(x))

print("\nresiduals:")
verbose = False
jax.ad_checkpoint.print_saved_residuals(g, x)

del x, y, fvjp, x_cotangent  # to make results consistent when re-running
Output:
forward:
MB used: 0.0
MB used: 1.0
MB used: 2.0
MB used: 3.0
MB used: 4.0

backward:
MB used: 6.0
MB used: 7.0
MB used: 7.0
MB used: 7.0
MB used: 7.0

residuals:
f32[262144] from the argument x
f32[262144] output of mul from /tmp/ipykernel_840435/435038629.py:20:15 (f.<locals>.f_fwd)
f32[262144] output of mul from /tmp/ipykernel_840435/435038629.py:20:15 (f.<locals>.f_fwd)
f32[262144] output of mul from /tmp/ipykernel_840435/435038629.py:20:15 (f.<locals>.f_fwd)
f32[262144] output of mul from /tmp/ipykernel_840435/435038629.py:20:15 (f.<locals>.f_fwd)

The memory usage increases step by step during the forward pass, since more and more residuals accumulate. As mentioned earlier, we can list all of the residuals with jax.ad_checkpoint.print_saved_residuals.

Using jax.jit can help eliminate many residuals. We will discuss this later.

It is possible to avoid storing residuals and instead instruct jax to recalculate them during the backward pass by using jax.checkpoint:

x = jnp.zeros(1024*256)  # ~ 1MB
bytes_start = get_used_bytes()

verbose = True
print("forward:")
y, fvjp = jax.vjp(jax.checkpoint(g), x)

print("\nbackward:")
x_cotangent = fvjp(jnp.ones_like(x))

print("\nresiduals:")
verbose = False
jax.ad_checkpoint.print_saved_residuals(jax.checkpoint(g), x)

del x, y, fvjp, x_cotangent  # to make results consistent when re-running

Output:
forward:
MB used: 0.0
MB used: 0.0
MB used: 0.0
MB used: 0.0
MB used: 0.0

backward:
MB used: 2.0
MB used: 2.0
MB used: 2.0
MB used: 2.0
MB used: 2.0

residuals:
f32[262144] from the argument x

Checkpointing is quite a sophisticated feature with many options. Consider having a look at the documentation of jax.checkpoint and this in-depth guide.
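One of these options is a policy that decides which intermediates are saved. As a sketch (the function and the name "mid" are made up for illustration), a chain of squarings like the one above can tag a single intermediate with jax.ad_checkpoint.checkpoint_name and save only that value via jax.checkpoint_policies.save_only_these_names; everything else is recomputed in the backward pass. The gradient stays the same, only the memory/compute trade-off changes.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def g(x):
    # Five successive squarings, similar to the chain of f's above.
    for k in range(5):
        x = x * x
        if k == 2:
            # Tag this intermediate so that a checkpoint policy can refer to it by name.
            x = checkpoint_name(x, "mid")
    return jnp.sum(x)

# Save only the value tagged "mid"; all other residuals are recomputed on the backward pass.
g_ckpt = jax.checkpoint(g, policy=jax.checkpoint_policies.save_only_these_names("mid"))

x = jnp.linspace(0.5, 1.0, 8)
grad_plain = jax.grad(g)(x)
grad_ckpt = jax.grad(g_ckpt)(x)

# Lists what is actually stored for the backward pass.
jax.ad_checkpoint.print_saved_residuals(g_ckpt, x)
```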

Note that forward-mode differentiation does not show this growth in memory, since no residuals need to be stored: each tangent is consumed as soon as it has been computed.

def f(x):
    def f_impl(x):
        print_bytes_diff()
        return x*x
    
    def f_jvp(primals, tangents):
        x, tangent = primals[0], tangents[0]
        print_bytes_diff()
        return x*x, tangent*2*x
    
    f_impl = jax.custom_jvp(f_impl)
    f_impl.defjvp(f_jvp)

    return f_impl(x)

def g(x):
    x = f(x)
    x = f(x)
    x = f(x)
    x = f(x)
    x = f(x)
    return x

x = jnp.zeros(1024*256)  # ~ 1MB
bytes_start = get_used_bytes()

verbose = True
y, y_tangent = jax.jvp(g, (x,), (jnp.ones_like(x),))

Output:
MB used: 1.0
MB used: 3.0
MB used: 3.0
MB used: 3.0
MB used: 3.0

Gradients and jax.jit

jax.jit can be quite good at eliminating residuals and fusing gradient computation and function evaluation into very few kernels. Consider e.g. our example from before:

def f(x):
    return x*x

def g(x):
    return jnp.sum(f(f(f(f(f(x))))))

x = jnp.zeros(1024*256)  # ~ 1MB

show_hlo_info(jax.jit(jax.value_and_grad(g)), x)

Output:
--------  Memory usage of g  ---------
const : 8 B
code  : 9.7 kB
temp  : 2.0 kB
arg   : 1.0 MB
output: 1.0 MB
alias : 0 B
peak  : 2.0 MB


However, things get worse for slightly more expensive functions, as the temporary memory in the first measurement below shows. In this case jax.checkpoint can help:

def f(x):
    return jnp.sin(jnp.cos(x*x))

def g(x):
    return jnp.sum(f(f(f(f(f(x))))))

x = jnp.zeros(1024*256)  # ~ 1MB

show_hlo_info(jax.jit(jax.value_and_grad(g)), x, mode="mem")

print("With checkpointing:")

show_hlo_info(jax.jit(jax.value_and_grad(jax.checkpoint(g))), x, mode="mem")

Output:
--------  Memory usage of g  ---------
const : 8 B
code  : 114.2 kB
temp  : 8.0 MB
arg   : 1.0 MB
output: 1.0 MB
alias : 0 B
peak  : 10.0 MB
With checkpointing:
--------  Memory usage of g  ---------
const : 8 B
code  : 130.9 kB
temp  : 16 B
arg   : 1.0 MB
output: 1.0 MB
alias : 0 B
peak  : 2.0 MB
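show_hlo_info is a helper defined earlier in the notebook. If you only want rough numbers like the ones above, a possible stand-in (a sketch, not the helper itself; compiled_memory is a hypothetical name) uses jax’s ahead-of-time API: lower and compile the function, then ask the compiled executable for its memory statistics with memory_analysis(). Its return value is backend-dependent and may be None, so the sketch reads the fields defensively.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(jnp.cos(x * x))

def g(x):
    return jnp.sum(f(f(f(f(f(x))))))

def compiled_memory(fun, *args):
    """Hypothetical stand-in for show_hlo_info(..., mode="mem"): compile fun for the
    given arguments and return the compiled executable and its memory statistics."""
    compiled = jax.jit(fun).lower(*args).compile()
    return compiled, compiled.memory_analysis()  # may be None on some backends

x = jnp.zeros(1024 * 256)  # ~ 1MB

for name, fun in [("plain", jax.value_and_grad(g)),
                  ("checkpointed", jax.value_and_grad(jax.checkpoint(g)))]:
    compiled, stats = compiled_memory(fun, x)
    # Field names follow XLA's compiled memory stats; read them defensively.
    temp = getattr(stats, "temp_size_in_bytes", None)
    print(f"{name}: temp bytes = {temp}")

# The compiled executable (here: the checkpointed variant) is callable like the jitted function.
value, grad_x = compiled(x)
```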

Loops

Jax’s reverse-mode autodiff tends to have horrendous memory requirements when differentiating loops, since it almost always stores the intermediate values of every iteration.

def f(i, x):
    return x*x + 1

def g(x):
    xfin = jax.lax.fori_loop(0, 20, f, x)
    return jnp.sum(xfin)

def g_chunked(x):
    def g5loops(i, x):
        xfin = jax.lax.fori_loop(i, i+5, f, x)
        return xfin
    g5loops = jax.checkpoint(g5loops, static_argnums=0)
    
    for i in range(4):
        x = g5loops(i*5, x)

    return jnp.sum(x)

x = jnp.ones((1024,256)) # ~ 1MB

show_hlo_info(jax.jit(jax.grad(g)), x, mode="mem")  # memory usage grows with the number of iterations

show_hlo_info(jax.jit(jax.grad(jax.checkpoint(g))), x, mode="mem")  # checkpointing doesn't help here (why?)

show_hlo_info(jax.jit(jax.grad(g_chunked)), x, mode="mem")  # chunking doesn't help either (why? possibly an effect of jit)

Output:
--------  Memory usage of g  ---------
const : 12 B
code  : 12.0 kB
temp  : 20.0 MB
arg   : 1.0 MB
output: 1.0 MB
alias : 0 B
peak  : 22.0 MB
--------  Memory usage of g  ---------
const : 12 B
code  : 12.5 kB
temp  : 20.0 MB
arg   : 1.0 MB
output: 1.0 MB
alias : 0 B
peak  : 22.0 MB
--------  Memory usage of g_chunked  ---------
const : 48 B
code  : 12.2 kB
temp  : 24.0 MB
arg   : 1.0 MB
output: 1.0 MB
alias : 0 B
peak  : 26.0 MB

I didn’t find a way to use checkpointing to avoid this. Please let me know if you know how to do this!

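JVPs, on the other hand, handle loops without trouble, since each iteration’s tangent can be consumed immediately:

(This does not mean, however, that we can replace the gradient by a forward-mode Jacobian here: a single JVP only yields one directional derivative, and jax.jacfwd needs one JVP per input element to assemble the full gradient, which explodes in memory as the second measurement below shows.)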

def jvp_g(x):
    return jax.jvp(g, (x,), (jnp.ones_like(x),))

show_hlo_info(jax.jit(jvp_g), x, mode="mem")

show_hlo_info(jax.jit(jax.jacfwd(g)), x, mode="mem") # explodes badly

Output:
--------  Memory usage of jvp_g  ---------
const : 16 B
code  : 13.0 kB
temp  : 2.0 MB
arg   : 1.0 MB
output: 24 B
alias : 0 B
peak  : 3.0 MB
--------  Memory usage of g  ---------
const : 8 B
code  : 15.2 kB
temp  : 256.5 GB
arg   : 1.0 MB
output: 1.0 MB
alias : 0 B
peak  : 256.5 GB


W1208 17:13:29.916063  840435 hlo_rematerialization.cc:3204] Can't reduce memory use below 5.70GiB (6116022681 bytes) by rematerialization; only reduced to 256.50GiB (275415826432 bytes), down from 256.50GiB (275415826432 bytes) originally
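To make the remark about forward Jacobians concrete on a small scale: a single JVP with tangent v gives the directional derivative ⟨∇g, v⟩ in one forward pass, while the full gradient in forward mode needs one JVP per input element. That is what jax.jacfwd does, and for the 262144-element input above it means 262144 tangents of 1 MB each, i.e. the 256 GB reported. The loop body below is a made-up, well-behaved stand-in so that nothing overflows.

```python
import jax
import jax.numpy as jnp

def body(i, x):
    # Bounded update on [0, 1], so values stay finite over many iterations.
    return x - 0.1 * x**3

def g(x):
    return jnp.sum(jax.lax.fori_loop(0, 10, body, x))

x = jnp.linspace(0.0, 1.0, 6)
v = jnp.ones_like(x)

# One JVP = one directional derivative <grad g(x), v>, from a single forward pass.
_, dg_v = jax.jvp(g, (x,), (v,))

# The full gradient via reverse mode ...
grad_rev = jax.grad(g)(x)
# ... and via forward mode, which uses one JVP per input element (jax.jacfwd).
grad_fwd = jax.jacfwd(g)(x)

print(dg_v, jnp.vdot(grad_rev, v))
```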

For the loop at hand, each step is reversible: for non-negative x, the update x ↦ x*x + c can be undone by x ↦ sqrt(x - c). We can therefore avoid storing intermediate values by defining a custom_vjp that reconstructs them one by one, starting from the endpoint:

def reversible_loop_with_vjp(body_fwd, body_bwd, init_val, nsteps):
    """Similar to jax.lax.fori_loop, but using a reversible approach in the VJP to save memory.
    
    body_fwd: the normal loop body, body_fwd(i, val) -> new val
    body_bwd: loop body that reverses body_fwd, i.e. body_bwd(i, body_fwd(i, val)) == val
    """
    @jax.custom_vjp
    def f(init_val):
        return jax.lax.fori_loop(0, nsteps, body_fwd, init_val)
    
    def f_fwd(init_val):
        final_val = f(init_val)
        # Only the final value is stored as a residual.
        return final_val, final_val
    
    def f_bwd(final_val, cotangent):
        def bwd_step(i, carry):
            val, cot = carry
            step = nsteps - i - 1           # walk through the steps in reverse order
            prev_val = body_bwd(step, val)  # reconstruct the input of this step
            # Pull the cotangent back through one forward step (w.r.t. the value only).
            _, step_vjp = jax.vjp(lambda v: body_fwd(step, v), prev_val)
            (prev_cot,) = step_vjp(cot)
            return prev_val, prev_cot

        _, initial_cot = jax.lax.fori_loop(0, nsteps, bwd_step, (final_val, cotangent))
        return (initial_cot,)
        
    f.defvjp(f_fwd, f_bwd)

    return f(init_val)

def step_fwd(i, x):
    return x*x + 1e-5
def step_bwd(i, x):
    return jnp.sqrt(x - 1e-5)

def g(x):
    xfin = reversible_loop_with_vjp(step_fwd, step_bwd, x, 5)
    return jnp.mean(xfin)

x = jnp.ones((1024,256))*0.1 # ~ 1MB

from jax.test_util import check_grads
check_grads(g, (x,), order=1, modes=("rev",))

show_hlo_info(jax.jit(jax.grad(g)), x, mode="mem")

Output:
--------  Memory usage of g  ---------
const : 16 B
code  : 10.5 kB
temp  : 1.0 MB
arg   : 1.0 MB
output: 1.0 MB
alias : 0 B
peak  : 3.0 MB
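Since the reversible VJP trusts body_bwd to reconstruct every intermediate value, it is worth checking numerically that it really undoes body_fwd. In floating point this only holds approximately: when the iterates come close to the additive constant c, the subtraction x - c in the inverse step cancels leading digits. A small helper for that check could look as follows (roundtrip_error is a hypothetical name; the constant and inputs are chosen for illustration so that the iterates stay well away from c). It may be worth running the same check with the parameters used above (x = 0.1, c = 1e-5), where the iterates approach c within a few steps.

```python
import jax
import jax.numpy as jnp

def roundtrip_error(body_fwd, body_bwd, x, nsteps):
    """Run nsteps of body_fwd, undo them with body_bwd, return the max abs deviation from x."""
    y = jax.lax.fori_loop(0, nsteps, body_fwd, x)
    x_rec = jax.lax.fori_loop(0, nsteps, lambda i, v: body_bwd(nsteps - i - 1, v), y)
    return jnp.max(jnp.abs(x_rec - x))

c = 0.1  # well-conditioned toy constant: iterates stay far from c

def step_fwd(i, x):
    return x * x + c

def step_bwd(i, x):
    return jnp.sqrt(x - c)

x = jnp.linspace(0.5, 1.0, 16)
err = roundtrip_error(step_fwd, step_bwd, x, 3)
print(err)
```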

jax.linear_transpose

If you ever get confused while writing a VJP about what the adjoint of some operation is, it helps to know about jax.linear_transpose. You can experiment with it to understand the adjoints of many linear operations:

print("adjoint of sum (= outer product with ones)")
def f(x):
    return jnp.sum(x)

x = jnp.linspace(0.,1., 10)
fT = jax.linear_transpose(f, x)
print(fT(2.)[0])

print("adjoint of gather (= scatter)")

x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
isort = jnp.argsort(x)  # permutation that sorts x (computed from the random x, not the linspace)
def f(x):
    return x[isort]

fT = jax.linear_transpose(f, x)
xsort = f(x)

print(fT(xsort)[0])  # scatters the sorted values back to their original positions
print(jnp.zeros_like(x).at[isort].add(xsort))  # gives the same

Output:
adjoint of sum (= outer product with ones)
[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
adjoint of gather (= scatter)
[0.947667   0.9785799  0.33229148 0.46866846 0.5698887  0.16550303
 0.3101946  0.68948054 0.74676657 0.17101455]
[0.947667   0.9785799  0.33229148 0.46866846 0.5698887  0.16550303
 0.3101946  0.68948054 0.74676657 0.17101455]

Another use case is inside a custom_vjp: there you can let jax transpose linear sub-computations with its built-in machinery instead of deriving their adjoints by hand.
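A minimal sketch of that pattern, assuming a function that is a nonlinearity applied after a linear map: the adjoint of the elementwise sin is written by hand, while the adjoint of the linear part is delegated to jax.linear_transpose. Here jnp.cumsum stands in for a more complicated linear sub-computation; linear_op and f are names made up for the example.

```python
import jax
import jax.numpy as jnp

def linear_op(x):
    # Stand-in for a linear sub-computation whose adjoint we don't want to derive by hand.
    return jnp.cumsum(x)

@jax.custom_vjp
def f(x):
    return jnp.sin(linear_op(x))

def f_fwd(x):
    y = linear_op(x)
    return jnp.sin(y), jnp.cos(y)  # residual: cos(y) for the nonlinear part

def f_bwd(cos_y, g):
    # Adjoint of the elementwise sin, written by hand.
    g_lin = g * cos_y
    # Adjoint of the linear part via jax.linear_transpose. Only the shape/dtype of the
    # primal argument matter; cumsum maps R^n to R^n, so g_lin has the shape/dtype of x.
    linear_op_T = jax.linear_transpose(linear_op, g_lin)
    (g_x,) = linear_op_T(g_lin)
    return (g_x,)

f.defvjp(f_fwd, f_bwd)

x = jnp.linspace(0.0, 1.0, 7)
w = jnp.arange(7.0)
grad_custom = jax.grad(lambda x: jnp.vdot(w, f(x)))(x)
grad_plain = jax.grad(lambda x: jnp.vdot(w, jnp.sin(linear_op(x))))(x)
print(grad_custom)
print(grad_plain)
```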

Some applications

Direct force summation

We want to compute gravitational forces through direct summation. The function below processes the particles in chunks to avoid creating an $N^2$ matrix in memory:

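def direct_force_scan(x, m=1., n2lim=1e8, eps=1e-2):
    """Implements a direct N-body force calculation using jax.lax.scan to limit memory usage.

    Each scan step handles a chunk of nmax particles, i.e. about n2lim pairwise interactions.
    """
    N = x.shape[0]

    nmax = int(np.ceil(n2lim / N))  # particles per chunk
    nev = int(np.ceil(N / nmax))    # number of chunks

    def force_segment(_, i):
        # Indices of the particles in this chunk (the last chunk may run past N;
        # those rows are discarded by the final slice below)
        isel = jnp.arange(nmax, dtype=jnp.int32) + i*nmax
        dx = x[isel,None] - x  # dx[a, j] = x_i - x_j with i = isel[a]
        
        r_ij2 = jnp.sum(dx**2, axis=-1)
        rinv =  1./jnp.sqrt(r_ij2 + eps**2)  # softened inverse distance

        # F_i = -sum_j m (x_i - x_j) / (r_ij^2 + eps^2)^(3/2), i.e. attractive
        return None, - jnp.sum(dx * rinv[...,None]**3 * m, axis=1)

    _, forces = jax.lax.scan(force_segment, None, jnp.arange(nev, dtype=jnp.int32))

    return jnp.concatenate(forces)[0:N]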

xpart = jax.random.normal(jax.random.PRNGKey(0), (1024*32, 3))
fpart = direct_force_scan(xpart)
print(fpart.shape)

Output:
(32768, 3)

While autodiff evaluates the JVP of this function efficiently, the VJP fails miserably:

# JVP works fine:
val, tang = jax.jvp(direct_force_scan, (xpart,), (jnp.ones_like(xpart),))

def fjvp(x):
    return jax.jvp(direct_force_scan, (x,), (jnp.ones_like(x),))
def fgrad(x):
    return jax.grad(lambda x: jnp.sum(direct_force_scan(x)))(x)

show_hlo_info(jax.jit(fjvp), xpart, mode="mem") # JVP works fine

show_hlo_info(jax.jit(fgrad), xpart, mode="mem") # VJP consumes 53GB of memory!

Output:
--------  Memory usage of fjvp  ---------
const : 64 B
code  : 25.9 kB
temp  : 772.8 MB
arg   : 384.0 kB
output: 768.0 kB
alias : 0 B
peak  : 773.9 MB


W1208 17:13:32.239346  840435 hlo_rematerialization.cc:3204] Can't reduce memory use below 5.70GiB (6116612505 bytes) by rematerialization; only reduced to 53.28GiB (57205066996 bytes), down from 53.28GiB (57205066996 bytes) originally


--------  Memory usage of fgrad  ---------
const : 100 B
code  : 68.8 kB
temp  : 53.3 GB
arg   : 384.0 kB
output: 384.0 kB
alias : 0 B
peak  : 53.3 GB

Since the VJP fails so badly, we should consider implementing a custom_vjp. We could do this by writing a custom loop that accumulates the cotangents chunk by chunk.

However, there is a much easier way out here. The Jacobian of the force is symmetric:

\(\begin{align} \nabla_{x_k} \vec{F}_i &= - \sum_{j \neq i} \nabla_{x_k} \nabla_{x_i} \phi(\vec{x}_i - \vec{x}_j) \nonumber \\ &= \sum_{j \neq i} \left( \delta_{ik} - \delta_{jk} \right) {T}_{ij} \nonumber \\ &= \nabla_{x_i} \vec{F}_k \end{align}\)

with the tidal tensor ${T}_{ij} = -\nabla \nabla \phi (\vec{x}_i - \vec{x}_j)$ and $\phi$ as the Green’s function of the gravitational potential. The last step uses that each ${T}_{ij}$ is a symmetric $3 \times 3$ matrix and that ${T}_{ij} = {T}_{ji}$, since $\phi$ depends only on the distance. Therefore, VJP and JVP are identical:

\(\begin{align} \sum_i g_i \frac{\partial F_i}{\partial \vec{x}_k} = \sum_i \frac{\partial F_k}{\partial \vec{x}_i} g_i \end{align}\)

This lets us simply replace the VJP by an invocation of the JVP, which we know jax handles efficiently:

def direct_force_scan_with_vjp(x, m=1., n2lim=1e8, eps=1e-1):
    def f_impl(x):
        return direct_force_scan(x, m, n2lim=n2lim, eps=eps)
    
    @jax.custom_vjp
    def f(x):
        return f_impl(x)
    def ffwd(x):
        return f_impl(x), x
    def fbwd(x_res, cotangent):
        # The force Jacobian is symmetric, so the VJP (J^T g) equals the JVP (J g).
        return jax.jvp(f_impl, (x_res,), (cotangent,))[1],
    f.defvjp(ffwd, fbwd)

    return f(x)

def loss(x):
    return jnp.sum(direct_force_scan_with_vjp(x)**2)

from jax.test_util import check_grads
check_grads(jax.jit(loss), (xpart,), order=1, modes=("rev",), eps=1e-2, rtol=1e-1)

show_hlo_info(jax.jit(jax.grad(loss)), xpart, mode="mem")

Output:
--------  Memory usage of loss  ---------
const : 80 B
code  : 35.2 kB
temp  : 768.3 MB
arg   : 384.0 kB
output: 384.0 kB
alias : 0 B
peak  : 769.0 MB
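The symmetry claim itself is easy to check numerically on a small system. The sketch below uses a dense, unchunked softened force (direct_force, a name made up for this check, only meant for small N), builds the full Jacobian with jax.jacfwd, and compares it with its transpose as well as a JVP with a VJP for the same vector.

```python
import jax
import jax.numpy as jnp

def direct_force(x, m=1.0, eps=1e-1):
    """Dense O(N^2) softened gravitational force, only meant for small N:
    F_i = sum_j m (x_j - x_i) / (|x_j - x_i|^2 + eps^2)^(3/2)."""
    dx = x[None, :, :] - x[:, None, :]  # dx[i, j] = x_j - x_i
    r2 = jnp.sum(dx**2, axis=-1)
    rinv3 = (r2 + eps**2) ** -1.5
    return jnp.sum(dx * (m * rinv3)[..., None], axis=1)

N = 6
x = jax.random.normal(jax.random.PRNGKey(0), (N, 3))

# Full Jacobian dF_(i,a)/dx_(k,b), flattened to a (3N, 3N) matrix.
J = jax.jacfwd(direct_force)(x).reshape(3 * N, 3 * N)
asym = jnp.max(jnp.abs(J - J.T))
scale = jnp.max(jnp.abs(J))

# For a symmetric Jacobian, the JVP and the VJP with the same vector agree.
g = jax.random.normal(jax.random.PRNGKey(1), (N, 3))
_, jvp_out = jax.jvp(direct_force, (x,), (g,))
_, vjp_fun = jax.vjp(direct_force, x)
(vjp_out,) = vjp_fun(g)

print("max |J - J^T| / max |J|:", asym / scale)
```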

We conclude that replacing VJPs by JVPs can be worthwhile when the structure of the problem allows it, because jax’s JVPs tend to be much more memory-friendly by default!

N-body simulation

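The Jacobian of the mapping from initial to final positions and velocities in an N-body simulation is symplectic. We can use this property to relate the VJP to the JVP of the inverse mapping; for the time-reversible kick-drift-kick integrator, the inverse is simply the same simulation run with -dt, starting from the final state. Beyond that, all it requires is rotating the position and velocity cotangents at the beginning and the end of the backward pass.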
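Spelled out, with $\Omega$ the standard symplectic matrix and bars denoting cotangents (notation introduced only for this derivation), the symplectic property of the Jacobian $J$ gives

\(\begin{align} J^T \Omega J &= \Omega, \qquad \Omega = \begin{pmatrix} 0 & I \\ -I & 0 \end{pmatrix}, \qquad \Omega^{-1} = -\Omega \nonumber \\ J^T &= \Omega J^{-1} \Omega^{-1} \nonumber \\ \begin{pmatrix} \bar{x}_0 \\ \bar{v}_0 \end{pmatrix} = J^T \begin{pmatrix} \bar{x}_n \\ \bar{v}_n \end{pmatrix} &= \Omega J^{-1} \begin{pmatrix} -\bar{v}_n \\ \bar{x}_n \end{pmatrix} \nonumber \end{align}\)

$J^{-1}$ is the Jacobian of the inverse map, i.e. of the simulation run backwards in time ($dt \to -dt$) from the final state. Writing $(\vec{a}, \vec{b}) = J^{-1} (-\bar{v}_n, \bar{x}_n)$ for the corresponding JVP, the cotangents of the initial state are $\Omega (\vec{a}, \vec{b}) = (\vec{b}, -\vec{a})$.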

def nbody_sim(pos, vel, dt, nsteps):
    def run_impl(pos, vel, dt):
        def kick_drift_kick_step(i, state):
            x, v = state
            v = v + direct_force_scan(x) * 0.5*dt
            x = x + v * dt
            v = v + direct_force_scan(x) * 0.5*dt
            return x, v
        
        return jax.lax.fori_loop(0, nsteps, kick_drift_kick_step, (pos, vel))
    
    @jax.custom_vjp
    def run(pos, vel):
        return run_impl(pos, vel, dt)
    
    def run_fwd(pos, vel):
        pos_vel = run_impl(pos, vel, dt)
        return pos_vel, pos_vel
    
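    def run_bwd(pos_vel, cot_pos_vel):
        pos, vel = pos_vel
        cpos, cvel = cot_pos_vel
        # Symplecticity gives J^T = Omega J^{-1} Omega^{-1}: rotate the cotangents (Omega^{-1}),
        # apply J^{-1} as a JVP of the time-reversed run (-dt) starting from the final state,
        # then rotate back (Omega).
        _, (cpos0, cvel0) = jax.jvp(run_impl, (pos, vel, -dt), (-cvel, cpos, 0.))
        return (cvel0, -cpos0)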
    
    run.defvjp(run_fwd, run_bwd)
    return run(pos, vel)

nbody_sim.jit = jax.jit(nbody_sim)

xpart = jax.random.normal(jax.random.PRNGKey(0), (1024*32, 3))
vpart = jax.random.normal(jax.random.PRNGKey(1), (1024*32, 3))

xfinal, vfinal = nbody_sim.jit(xpart, vpart, dt=1e-3, nsteps=10)

def loss(x, v):
    xfinal, vfinal = nbody_sim(x, v, dt=1e-3, nsteps=10)
    return jnp.mean(xfinal**2) + jnp.mean(vfinal**2)

grad = jax.jit(jax.grad(loss))(xpart, vpart)

check_grads(jax.jit(loss), (xpart, vpart), order=1, modes=("rev",), eps=1e-3, rtol=1e-1)

show_hlo_info(jax.jit(jax.grad(loss)), xpart, vpart, mode="mem")

Output:
--------  Memory usage of loss  ---------
const : 236 B
code  : 53.1 kB
temp  : 774.7 MB
arg   : 768.0 kB
output: 384.0 kB
alias : 0 B
peak  : 775.8 MB

Note that we didn’t even need the VJP of the force here (since we are implicitly requesting a JVP)!
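The rotate-and-reverse trick used in run_bwd can be checked against ordinary reverse-mode autodiff on a small system. The sketch below keeps the kick-drift-kick structure but swaps in a made-up conservative toy force (so it runs in milliseconds instead of calling direct_force_scan); both ways of computing the VJP should agree up to floating-point round-off.

```python
import jax
import jax.numpy as jnp

def force(x):
    # Conservative toy force F = -dV/dx with V(x) = x^2/2 + x^4/40 (per component).
    return -x - 0.1 * x**3

def leapfrog(pos, vel, dt, nsteps=20):
    def kick_drift_kick_step(i, state):
        x, v = state
        v = v + force(x) * 0.5 * dt
        x = x + v * dt
        v = v + force(x) * 0.5 * dt
        return x, v
    return jax.lax.fori_loop(0, nsteps, kick_drift_kick_step, (pos, vel))

pos = jnp.array([0.3, -1.2, 0.8])
vel = jnp.array([0.5, 0.1, -0.4])
cpos = jnp.array([1.0, -0.5, 0.2])  # cotangent w.r.t. final positions
cvel = jnp.array([0.3, 0.7, -1.1])  # cotangent w.r.t. final velocities
dt = 0.05

# Reference: ordinary reverse-mode autodiff through the loop.
(pos_f, vel_f), vjp_fun = jax.vjp(lambda p, v: leapfrog(p, v, dt), pos, vel)
ref_cpos0, ref_cvel0 = vjp_fun((cpos, cvel))

# Symplectic trick as in run_bwd: rotate the cotangents, push them through a JVP of
# the time-reversed run starting at the final state, and rotate back.
_, (a, b) = jax.jvp(lambda p, v: leapfrog(p, v, -dt), (pos_f, vel_f), (-cvel, cpos))
trick_cpos0, trick_cvel0 = b, -a

print(ref_cpos0, trick_cpos0)
print(ref_cvel0, trick_cvel0)
```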

Resources

This post is licensed under CC BY 4.0 by the author.