JAX and Backend

[1]:
import sys
import os

sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))

DESC uses JAX for faster execution through just-in-time (JIT) compilation, automatic differentiation, and other scientific computing tools. The purpose of backend.py is to determine whether DESC can take advantage of JAX and GPUs or should default to standard numpy and CPUs. To run DESC on a GPU, place the following lines before importing anything else from DESC:

[2]:
# from desc import set_device
# set_device("gpu")

You can check whether DESC is running on a CPU or a GPU with print_backend_info(). This prints the DESC and JAX (or NumPy) versions and the device information.

[3]:
from desc.backend import print_backend_info

print_backend_info()
DESC version=0.15.0+36.gf13760717.dirty.
Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.
Using device: CPU, with 17.15 GB available memory.

JAX provides a numpy-style API for array operations. In many cases, taking advantage of JAX only requires replacing calls to numpy with calls to jax.numpy. A convenient way to do this is the import statement import jax.numpy as jnp.

[4]:
from desc.backend import jnp
import numpy as np
[5]:
# give some JAX examples
zeros_jnp = jnp.zeros(4)
zeros_np = np.zeros(4)

print(zeros_jnp)
print(zeros_np)
[0. 0. 0. 0.]
[0. 0. 0. 0.]

Of course, if such an import statement is used in DESC and DESC is run on a machine where JAX is not installed, the import fails with an error. We would prefer that DESC still work on machines where JAX is not installed. With that goal, functions that can benefit from JAX use the following import statement instead: from desc.backend import jnp. desc.backend.jnp is an alias for jax.numpy if JAX is installed and for numpy otherwise.
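The fallback described above can be sketched as a guarded import. This is only an illustration of the pattern, not the actual contents of backend.py, and the use_jax flag is a hypothetical name introduced here.

```python
# Sketch of a guarded import: prefer jax.numpy, fall back to NumPy.
try:
    import jax.numpy as jnp

    use_jax = True  # hypothetical flag, for illustration only
except ImportError:
    import numpy as jnp

    use_jax = False

# Code written against the shared subset of the two APIs runs either way.
x = jnp.linspace(0.0, 1.0, 5)
total = float(jnp.sum(x**2))
print(use_jax, total)
```

Code that sticks to functions available in both libraries does not need to know which one is behind the jnp name.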

While jax.numpy attempts to serve as a drop-in replacement for numpy, it imposes some constraints on how code is written. For example, jax.numpy arrays are immutable, which means in-place updates to array elements are not possible. To update elements of a jax.numpy array, memory needs to be allocated for a new array that contains the updated elements. Similarly, JAX’s JIT compilation requires control flow structures such as loops and conditionals to be written in a specific way.

The utility functions in desc.backend provide a simple interface to perform these operations.

[6]:
zeros_jnp = jnp.zeros(4)
# this will give an error
# zeros_jnp[0] = 1
# we need to use the at[] method
zeros_jnp = zeros_jnp.at[0].set(1)
print(zeros_jnp)
[1. 0. 0. 0.]
[7]:
# or to make this compatible with numpy backend we can use the following
from desc.backend import put

zeros_jnp = put(zeros_jnp, 0, 2)
print(zeros_jnp)
[2. 0. 0. 0.]
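The control-flow constraint mentioned earlier can be illustrated in the same spirit. The sketch below uses jax.lax.fori_loop and jax.lax.cond directly from JAX rather than any DESC wrapper, and the function name sum_positive is made up for this example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_positive(x):
    """Sum the entries of x, counting negative entries as zero."""

    def body(i, total):
        # The condition depends on a traced value, so a Python `if` would fail
        # under jit; lax.cond traces both branches and picks one at run time.
        term = jax.lax.cond(
            x[i] > 0,
            lambda v: v,
            lambda v: jnp.zeros_like(v),
            x[i],
        )
        return total + term

    # A loop with a known trip count, expressed as a JAX primitive.
    return jax.lax.fori_loop(0, x.shape[0], body, jnp.zeros((), x.dtype))


result = sum_positive(jnp.array([1.0, -2.0, 3.0]))
print(result)
```

For a reduction this simple, jnp.sum(jnp.where(x > 0, x, 0)) would be the more natural choice; the loop form is shown only to illustrate how control flow is expressed.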

Since the JAX documentation does a really good job of explaining the similarities and differences between jax.numpy and numpy, we won’t go too deep here; we only mention some of the major differences to get you started.

Technically, most operations can be written using numpy (as long as they are outside of jax.jit), but in most cases jax.numpy is faster, and it can use both CPUs and GPUs without any code change. JAX arrays can live on different devices and can also take advantage of efficient implementations of a function that depend on the hardware used.
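As a small illustration of arrays living on devices, the following sketch lists the devices JAX can see and places an array on the first one. On a CPU-only machine the list is expected to contain a single CPU device.

```python
import jax
import jax.numpy as jnp

# All devices visible to JAX in this process.
devices = jax.devices()
print(devices)

x = jnp.arange(4.0)
# Explicitly place the array on the first available device.
x_on_device = jax.device_put(x, devices[0])
# devices() returns the set of devices that hold the array.
print(x_on_device.devices())
```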

It is still good practice to test both versions to see which one is faster (for functions outside of jit). One important point to consider during profiling is to use block_until_ready() as explained here. If you specifically want the numpy version of a function, just import numpy as usual instead of switching the whole code to the numpy backend. There are a couple of places in the code where we specifically use numpy functions. There are different reasons for this: for example, since JAX arrays are immutable, they sometimes behave unexpectedly in loops, and some jax.numpy functions have overhead that makes them slower than their numpy counterparts for single use.
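A rough way to compare the two versions is sketched below. The helper best_time and the matrix size are arbitrary choices for illustration, and the measured times depend entirely on the machine.

```python
import time

import jax.numpy as jnp
import numpy as np

a_np = np.random.default_rng(0).standard_normal((500, 500)).astype(np.float32)
a_jnp = jnp.asarray(a_np)


def best_time(fn, repeats=5):
    """Return the smallest wall-clock time over several calls of fn."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


t_np = best_time(lambda: a_np @ a_np)

# JAX dispatches work asynchronously. block_until_ready() waits for the
# result, so the timer measures the computation and not just the dispatch.
(a_jnp @ a_jnp).block_until_ready()  # warm-up call, excluded from timing
t_jnp = best_time(lambda: (a_jnp @ a_jnp).block_until_ready())

print(f"numpy: {t_np:.2e} s, jax.numpy: {t_jnp:.2e} s")
```

Without the block_until_ready() call, the JAX timing can look misleadingly small because control returns to Python before the result is computed.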

There is a plan to remove the numpy backend, since some portions of the code use JAX or related functions that have no equivalents elsewhere, and code that relies on the numpy backend instead of JAX is not automatically tested for correctness by the GitHub CI. Depending on the backend, DESC automatically chooses which method of differentiation to use. If JAX is not installed, it uses finite differences for derivatives.
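To make the difference between the two differentiation methods concrete, the sketch below compares jax.grad with a hand-written central finite difference on a toy function. It does not reproduce DESC's derivative classes; the names f and finite_difference_grad and the step size h are assumptions made for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(x, xp):
    """Toy objective; xp is the array module (jax.numpy or numpy)."""
    return xp.sum(xp.sin(x) ** 2)


x0 = np.array([0.1, 0.5, 1.0])

# Automatic differentiation: exact up to floating-point rounding.
grad_ad = np.asarray(jax.grad(lambda x: f(x, jnp))(jnp.asarray(x0)))


def finite_difference_grad(fun, x, h=1e-6):
    """Central finite-difference gradient, one coordinate at a time."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = h
        grad[i] = (fun(x + step) - fun(x - step)) / (2 * h)
    return grad


grad_fd = finite_difference_grad(lambda x: f(x, np), x0)

print(grad_ad)
print(grad_fd)
```

Both results approximate the analytic gradient sin(2x). The finite-difference version needs two function evaluations per input coordinate and its accuracy depends on the step size, which is part of why automatic differentiation is preferred when it is available.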

JAX Tricks and Tips

Compilation

Just-in-time compiled functions are great; however, the compilation cost is usually high. That means if you evaluate a function only a couple of times, you can be worse off due to the compilation overhead. For functions you will call many times, the aim is to reduce the number of recompilations. For instance,

@functools.partial(jax.jit, static_argnums=1)
def fun(x, dx=0):
    ...
    return ...

will be compiled for every distinct value of dx. Another reason for recompilation can be the argument x. The primary cause of most recompilations is a change in array shape, i.e. x=jnp.arange(5) and x=jnp.arange(10) will use different compiled code. Similarly, x=jnp.float64(5.0) and x=jnp.int64(5) will also need two different compiled versions. These types of recompilation are trickier to detect because of the concept of weak/strong types. In a nutshell, one needs to use the same data type to prevent recompilation. For example, float(3.14) and jnp.float64(3.14) are different in terms of JAX’s cache checking. Because of this, in some parts of the code we have nested calls like jnp.float64(float(x)). The inner float(x) makes sure that x is a scalar value, and the outer jnp.float64 casts it to the final type.
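The recompilation triggers described above can be observed with a Python side effect, which runs only while JAX traces the function and not on cached calls. The function scaled_sum and the trace_log list are illustrative names, not part of DESC.

```python
import functools

import jax
import jax.numpy as jnp

trace_log = []  # one entry is appended each time the function is traced


@functools.partial(jax.jit, static_argnums=1)
def scaled_sum(x, dx=0):
    # Python side effects run during tracing only, not on cached calls.
    trace_log.append((x.shape, str(x.dtype), dx))
    return jnp.sum(x) + dx


scaled_sum(jnp.arange(5), 0)  # first call: traced and compiled
scaled_sum(jnp.arange(5), 0)  # same shape, dtype and dx: cache hit
scaled_sum(jnp.arange(5), 1)  # new value of the static argument dx
scaled_sum(jnp.arange(10), 1)  # new array shape
scaled_sum(jnp.arange(10.0), 1)  # new dtype (float instead of int)

print(len(trace_log))  # 4 traces for 5 calls
```

Only the second call reuses existing compiled code; every other call differs from all previous ones in a static value, a shape, or a dtype.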

JAX stores the compiled binaries in a cache (this can be made persistent by the process explained here) and recompiles if there is a cache miss. One of the best ways to detect recompilation is the context manager,

with jax.log_compiles():
    fun()

One can set the following configuration option to detect the cause of a cache miss. This gives detailed information on the input arguments and the closest matches in the cache. For example, you can detect a data type change with this method. Add this to the top of your code, before using any JAX functions.

jax.config.update('jax_explain_cache_misses', True)
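Putting the two diagnostics together, a minimal self-contained sketch might look like the following. The function sum_of_squares is a made-up example, and the exact wording of the log messages depends on the JAX version.

```python
import jax
import jax.numpy as jnp

# Report the reason for each cache miss through Python logging.
jax.config.update("jax_explain_cache_misses", True)


@jax.jit
def sum_of_squares(x):
    return jnp.sum(x**2)


# Compilations that happen inside this block are logged.
with jax.log_compiles():
    first = sum_of_squares(jnp.arange(3.0))  # compiled for shape (3,)
    second = sum_of_squares(jnp.arange(3.0))  # cache hit, no new compilation
    third = sum_of_squares(jnp.arange(4.0))  # new shape, compiled again

print(first, second, third)
```

The messages are emitted through Python's logging module, so they appear on standard error unless logging is configured otherwise.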