Installation

Via pip

The easiest way to install is via pip. If you have jax[cuda12], simply use

pip install jztree[cuda12]

and for jax[cuda13]:

pip install jztree[cuda13]

The CUDA 13 wheel supports Python 3.11-3.14 and the CUDA 12 wheel supports Python 3.11-3.13. If you are outside of this range, you may still have success by building from source, but be aware that jax itself also has a limited compatibility range.

Build from source

First of all, clone the repository:

git clone https://github.com/jstuecker/jz-tree

Then you need to check whether your GPU supports CUDA 13 or CUDA 12 (older CUDA versions are not supported by jax) and follow the matching installation section below.

Take note of your compute capability. You can significantly speed up the build by setting the CUDAARCHS environment variable, e.g.

export CUDAARCHS=87

for compute capability 8.7. By default we use CUDAARCHS="all" to build for all architectures, which can take a very long time (20-30 minutes rather than 2). You may also set CUDAARCHS="native" to automatically detect your system's architecture.
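If you do not know your compute capability, recent drivers let you query it directly (the `compute_cap` query field requires a reasonably new nvidia-smi; on older drivers, look up your GPU model instead):

```shell
# Query the compute capability of the first GPU (e.g. "8.7")
# and strip the dot to obtain the CUDAARCHS value (e.g. "87"):
export CUDAARCHS=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n1 | tr -d .)
echo "$CUDAARCHS"
```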

CUDA13 installation

The simplest way to install with CUDA13 is via pip. First, install the build dependencies:

pip install "jax[cuda13]" scikit-build-core nanobind "cmake>=3.24" setuptools_scm

Finally, install jz-tree with --no-build-isolation

pip install -e . --no-build-isolation

Note

If you do an editable installation without --no-build-isolation, your Python may have trouble locating the CUDA modules.

(Note: Installation may be significantly faster with uv pip.)

CUDA12 installation

To build with CUDA12 independently of system installations, we require a conda distribution, since the nvidia-cuda-nvcc-cu12 pip package does not ship the nvcc compiler binary. However, we can install it with a conda package.

Install miniforge (or any other conda distribution), set up an environment, and activate it (skip steps as appropriate):

curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash "Miniforge3-$(uname)-$(uname -m).sh" -b
rm "Miniforge3-$(uname)-$(uname -m).sh"
eval "$(~/miniforge3/bin/conda shell.bash hook)" # possibly fill in the correct directory
conda init
conda create --name jzenv -y
conda activate jzenv

Install prerequisites via conda and pip:

conda install -c conda-forge pip cuda-nvcc cuda-version=12 cudnn nccl libcufft cuda-cupti libcublas libcusparse
pip install scikit-build-core nanobind "cmake>=3.24" setuptools_scm
pip install --upgrade "jax[cuda12-local]"

Finally, install the code with

pip install -e . --no-build-isolation

or with the [dev] optional dependencies if you’d like to use unit tests and some optional features.
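For example, to include the dev extras (the quotes keep shells like zsh from interpreting the square brackets; the `dev` extra is assumed to be defined in the project's optional dependencies):

```shell
pip install -e ".[dev]" --no-build-isolation
```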

System CUDA setup:

I assume that you have already installed or loaded some CUDA version; check, e.g., with

nvcc --version

Install:

pip install --upgrade "jax[cuda13-local]"   # or jax[cuda12-local]
pip install scikit-build-core nanobind "cmake>=3.24" setuptools_scm
pip install -e . --no-build-isolation

Note that you may run into trouble if a second CUDA version is installed in your Python environment.
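A quick way to spot such duplicates is to list the CUDA-related packages in your environment (the package names shown are illustrative; with jax[cuda13-local], for instance, you should not also see cu12 wheels):

```shell
# List CUDA-related packages installed in the current Python environment:
pip list | grep -i -e cuda -e nvidia
```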

Speeding up build-time

As mentioned earlier, the primary way to speed up the build is to set the CUDAARCHS environment variable. For example,

CUDAARCHS=87 pip install -e . --no-build-isolation

A more advanced way of reducing the build time is to reduce the number of template variants that are instantiated, by modifying the code generation script src/_generate_ffi.py (and executing it again). This is explained in more detail in CUDA kernels and automatic FFI generation.

Hello World

You can verify that the installation was successful by running

python checks/hello_world.py

This should give you something like this:

rnn: [[0.         0.00327169 0.00362817 0.00418469 0.00620413 0.00629311
  0.00657683 0.0065819 ]
 [0.         0.0018539  0.00218362 0.00325193 0.00418783 0.00457177
  0.00464585 0.00483315]
 [0.         0.00392929 0.00410999 0.00522736 0.00623543 0.00679859
  0.006818   0.00706907]]
Should be:
[[0.         0.00327169 0.00362817 0.00418469 0.00620413 0.00629311
  0.00657683 0.0065819 ]
 [0.         0.0018539  0.00218362 0.00325193 0.00418783 0.00457177
  0.00464585 0.00483315]
 [0.         0.00392929 0.00410999 0.00522736 0.00623543 0.00679859
  0.006818   0.00706907]]

If you get a warning like this at the beginning

E0406 01:02:01.239070 1581469 cuda_executor.cc:1743] Could not get kernel mode driver version: [INVALID_ARGUMENT: Version does not match the format X.Y.Z]

don’t worry: it stems from a recent jax version and can be safely ignored, or silenced by setting the environment variable TF_CPP_MIN_LOG_LEVEL=3.
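For example, to silence it for a single run:

```shell
# TF_CPP_MIN_LOG_LEVEL=3 suppresses informational and warning log output
# for this invocation only:
TF_CPP_MIN_LOG_LEVEL=3 python checks/hello_world.py
```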