{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# JAX and Backend" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "import os\n", "\n", "sys.path.insert(0, os.path.abspath(\".\"))\n", "sys.path.append(os.path.abspath(\"../../../\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "DESC uses JAX for faster execution times with just-in-time (JIT) compilation, automatic differentiation, and other scientific computing tools.\n", "The purpose of ``backend.py`` is to determine whether DESC may take advantage of JAX and GPUs or default to standard ``numpy`` and CPUs. To run DESC on a GPU, place the following code before importing anything else from DESC:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# from desc import set_device\n", "# set_device(\"gpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can check whether DESC is running on a CPU or GPU with `print_backend_info()`.\n", "This will print the DESC and JAX or NumPy versions, and the device information." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DESC version=0.15.0+36.gf13760717.dirty.\n", "Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.\n", "Using device: CPU, with 17.15 GB available memory.\n" ] } ], "source": [ "from desc.backend import print_backend_info\n", "\n", "print_backend_info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "JAX provides a ``numpy``-style API for array operations.\n", "In many cases, to take advantage of JAX, one only needs to replace calls to ``numpy`` with calls to ``jax.numpy``.\n", "A convenient way to do this is with the import statement ``import jax.numpy as jnp``."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from desc.backend import jnp\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0. 0. 0. 0.]\n", "[0. 0. 0. 0.]\n" ] } ], "source": [ "# give some JAX examples\n", "zeros_jnp = jnp.zeros(4)\n", "zeros_np = np.zeros(4)\n", "\n", "print(zeros_jnp)\n", "print(zeros_np)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Of course, if such an import statement is used in DESC and DESC is run on a machine where JAX is not installed, the import raises an error.\n", "We would prefer that DESC still work on machines where JAX is not installed.\n", "With that goal, in functions which can benefit from JAX, we use the following import statement: ``from desc.backend import jnp``.\n", "``desc.backend.jnp`` is an alias to ``jax.numpy`` if JAX is installed and ``numpy`` otherwise.\n", "\n", "While ``jax.numpy`` attempts to serve as a drop-in replacement for ``numpy``, it imposes some constraints on how the code is written.\n", "For example, ``jax.numpy`` arrays are immutable.\n", "This means that in-place updates to array elements are not possible.\n", "To update elements in a ``jax.numpy`` array, memory must be allocated for a new array that contains the updated element.\n", "Similarly, JAX's JIT compilation requires control flow structures such as loops and conditionals to be written in a specific way.\n", "\n", "The utility functions in ``desc.backend`` provide a simple interface to perform these operations." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1. 0. 0.
0.]\n" ] } ], "source": [ "zeros_jnp = jnp.zeros(4)\n", "# this will give an error\n", "# zeros_jnp[0] = 1\n", "# we need to use the at[] method\n", "zeros_jnp = zeros_jnp.at[0].set(1)\n", "print(zeros_jnp)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2. 0. 0. 0.]\n" ] } ], "source": [ "# or to make this compatible with numpy backend we can use the following\n", "from desc.backend import put\n", "\n", "zeros_jnp = put(zeros_jnp, 0, 2)\n", "print(zeros_jnp)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the `JAX` documentation explains the similarities and differences between `jax.numpy` and `numpy` well, we won't go too deep here and only mention some of the major differences to get you started.\n", "\n", "Technically, most operations can be written using `numpy` (as long as they are outside of `jax.jit`), but in most cases `jax.numpy` is faster, and it can run on both CPUs and GPUs without any code changes. `jax.Array`s can live on different devices and take advantage of efficient, hardware-specific implementations of a function." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is still good practice to test both versions to see which one is faster (for functions outside of jit). One important point when profiling is to call `block_until_ready()` on the result, as explained [here](https://jax.readthedocs.io/en/latest/async_dispatch.html), because JAX dispatches computations asynchronously. If you specifically want the `numpy` version of a function, rather than the `numpy` backend for the whole code, just import `numpy` as usual. In a couple of places in the code, we deliberately use `numpy` functions. The reasons vary: for example, since `jax.Array`s are immutable, they sometimes behave unexpectedly in loops, and some `jax.numpy` functions have overhead that makes them slower than their `numpy` counterparts for single use."
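, "\n", "\n", "As a minimal sketch of such a comparison (this is not code from DESC; the array size and variable names are arbitrary assumptions, and `jax.numpy` is imported directly for simplicity), the snippet below times the same matrix product with both libraries. The first call to a `jax.numpy` function typically includes one-time overhead, so it is called once before timing:\n", "\n", "```python\n", "import time\n", "\n", "import jax.numpy as jnp\n", "import numpy as np\n", "\n", "x_np = np.ones((500, 500))\n", "x_jnp = jnp.asarray(x_np)\n", "\n", "# warm-up call so that one-time JAX overhead is not included in the timing\n", "(x_jnp @ x_jnp).block_until_ready()\n", "\n", "t0 = time.perf_counter()\n", "y_np = x_np @ x_np\n", "t_np = time.perf_counter() - t0\n", "\n", "t0 = time.perf_counter()\n", "# wait for the asynchronously dispatched result before stopping the clock\n", "y_jnp = (x_jnp @ x_jnp).block_until_ready()\n", "t_jnp = time.perf_counter() - t0\n", "\n", "print(f'numpy: {t_np:.2e} s, jax.numpy: {t_jnp:.2e} s')\n", "```\n", "\n", "Which version is faster depends on the array size and the hardware, so it is worth repeating the measurement for the sizes you actually use."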
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There is a plan to remove the `numpy` backend, since some portions of the code use `JAX` or related functions that have no other equivalents, and code that relies on the `numpy` backend instead of `JAX` is not automatically tested for correctness by the GitHub CI. Depending on the backend, DESC automatically chooses which method of differentiation to use. If `JAX` is not installed, it uses finite differences for derivatives." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## JAX Tricks and Tips" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compilation\n", "\n", "Just-in-time compiled functions are great; however, the compilation cost is usually high. That means that if you evaluate a function only a couple of times, you can be worse off due to the compilation overhead. For functions you call many times, the aim is to reduce the number of recompilations. For instance,\n", "\n", "```python\n", "import functools\n", "import jax\n", "\n", "@functools.partial(jax.jit, static_argnums=1)\n", "def fun(x, dx=0):\n", "    ...\n", "    return ...\n", "```\n", "\n", "will be compiled separately for every distinct value of `dx`, because `dx` is marked as a static argument. Another reason for recompilation can be the argument `x`. The most common cause of recompilation is a change in array shape, e.g. `x=jnp.arange(5)` and `x=jnp.arange(10)` will use different compiled code. Similarly, `x=jnp.float64(5.0)` and `x=jnp.int64(5)` will need two different compiled versions. These types of recompilation are trickier to detect because of the concept of weak/strong types. In a nutshell, one needs to use the same data type to prevent recompilation. For example, `float(3.14)` and `jnp.float64(3.14)` are different in terms of JAX's cache checking. Because of this, some parts of the code have nested calls like `jnp.float64(float(x))`.
The inner `float(x)` makes sure that `x` is a scalar value, and the outer `jnp.float64` casts it to the final type.\n", "\n", "JAX stores the compiled binaries in a cache (this can be made persistent by [the process explained here](https://desc-docs.readthedocs.io/en/stable/performance_tips.html)) and recompiles if there is a cache miss. One of the best ways to detect recompilation is the context manager,\n", "\n", "```python\n", "with jax.log_compiles():\n", "    fun()\n", "```\n", "\n", "One can set the following configuration option to find the cause of a cache miss. It gives detailed information on the input arguments and the closest matches in the cache. For example, you can detect a data type change with this method. Add this to the top of your code, before using any `JAX` functions.\n", "\n", "```python\n", "jax.config.update('jax_explain_cache_misses', True)\n", "```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "cpu", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.0" } }, "nbformat": 4, "nbformat_minor": 2 }