Installation
Via pip
The easiest way to install is via pip. If you have jax[cuda12], simply run
pip install jztree[cuda12]
and for jax[cuda13]:
pip install jztree[cuda13]
The CUDA 13 wheel supports Python 3.11-3.14 and the CUDA 12 wheel supports Python 3.11-3.13. If you are outside this range, you may still have success building from source, but be aware that jax itself has a limited compatibility range.
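Before picking a wheel you can check your interpreter version against these ranges. The ranges below come from the text above; the helper function itself is just an illustrative sketch, not part of jztree.

```python
import sys

# Supported Python version ranges for the prebuilt wheels (from the text above):
# cuda13 wheel: 3.11-3.14, cuda12 wheel: 3.11-3.13.
WHEEL_SUPPORT = {
    "cuda13": ((3, 11), (3, 14)),
    "cuda12": ((3, 11), (3, 13)),
}

def wheel_available(variant, version=None):
    """Return True if a prebuilt jztree wheel exists for this Python version."""
    if version is None:
        version = sys.version_info[:2]
    lo, hi = WHEEL_SUPPORT[variant]
    return lo <= tuple(version) <= hi

print(wheel_available("cuda12", (3, 12)))  # True: covered by the cuda12 wheel
print(wheel_available("cuda12", (3, 14)))  # False: build from source instead
```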
Build from sources
First of all, clone the repository:
git clone https://github.com/jstuecker/jz-tree
Then you need to check whether your GPU supports CUDA13 or CUDA12 (older CUDA versions are not supported by jax). To install with CUDA13 you need
A GPU with compute capability >=7.5.
A sufficiently new graphics driver >= 580. You can check your GPU and your driver with
nvidia-smi
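nvidia-smi also has a machine-readable query mode (nvidia-smi --query-gpu=driver_version,compute_cap --format=csv,noheader); whether the compute_cap field is available depends on your driver version, so treat the following as a sketch. The sample lines and thresholds mirror the requirements listed above; the parsing helpers are hypothetical.

```python
CUDA13_MIN_DRIVER = 580   # minimum driver major version, from the list above
CUDA13_MIN_CC = (7, 5)    # minimum compute capability, from the list above

def parse_gpu_info(csv_line):
    """Parse one CSV line like '580.65.07, 8.7' (driver version, compute cap)."""
    driver, cc = (field.strip() for field in csv_line.split(","))
    driver_major = int(driver.split(".")[0])
    cc_tuple = tuple(int(x) for x in cc.split("."))
    return driver_major, cc_tuple

def cuda13_ok(csv_line):
    """Check both CUDA 13 requirements against one parsed nvidia-smi line."""
    driver_major, cc = parse_gpu_info(csv_line)
    return driver_major >= CUDA13_MIN_DRIVER and cc >= CUDA13_MIN_CC

# The sample values below are made up for illustration:
print(cuda13_ok("580.65.07, 8.7"))   # True
print(cuda13_ok("535.104.05, 8.7"))  # False: driver too old for CUDA 13
```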
Take note of your compute capability. You can significantly speed up the build time
by setting the CUDAARCHS environment variable, e.g.
export CUDAARCHS=87
for compute capability 8.7. By default we use CUDAARCHS="all" to build for all architectures. This
may take a very long time (20-30 minutes rather than 2). You may also set CUDAARCHS="native" to automatically detect your system's architecture.
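The CUDAARCHS value is simply the compute capability with the dot dropped (8.7 becomes 87), which a one-liner can derive; this tiny helper is illustrative only.

```python
def cuda_arch(compute_capability):
    """Convert a compute capability string like '8.7' to a CUDAARCHS value like '87'."""
    return compute_capability.replace(".", "")

print(cuda_arch("8.7"))   # 87
print(cuda_arch("12.0"))  # 120
```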
CUDA13 installation
The simplest way to install with CUDA13 is via pip. First, install the build dependencies:
pip install "jax[cuda13]" scikit-build-core nanobind "cmake>=3.24" setuptools_scm
Finally, install jz-tree with --no-build-isolation
pip install -e . --no-build-isolation
Note
If you do an editable installation without --no-build-isolation, Python may have trouble locating
the CUDA modules.
(Note: installation may be significantly faster with uv pip.)
CUDA12 installation
To build with CUDA12 independently of system installations, we require a conda distribution:
the nvidia-cuda-nvcc-cu12 pip package does not ship the nvcc compiler binary, but the
corresponding conda package does.
Install miniforge (or any other conda distribution), set up an environment, and activate it (skip steps as appropriate):
curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash "Miniforge3-$(uname)-$(uname -m).sh" -b
rm "Miniforge3-$(uname)-$(uname -m).sh"
eval "$(~/miniforge3/bin/conda shell.bash hook)" # adjust the path if you installed elsewhere
conda init
conda create --name jzenv -y
conda activate jzenv
Install prerequisites via conda and pip:
conda install -c conda-forge pip cuda-nvcc cuda-version=12 cudnn nccl libcufft cuda-cupti libcublas libcusparse
pip install scikit-build-core nanobind "cmake>=3.24" setuptools_scm
pip install --upgrade "jax[cuda12-local]"
Finally, install the code with
pip install -e . --no-build-isolation
or with the [dev] optional dependencies (pip install -e ".[dev]" --no-build-isolation) if you'd like to run the unit tests and use some optional features.
System CUDA setup:
This assumes you have already installed or loaded a system CUDA toolkit; check, e.g., with
nvcc --version
Install:
pip install --upgrade "jax[cuda13-local]" # or jax[cuda12-local]
pip install scikit-build-core nanobind "cmake>=3.24" setuptools_scm
pip install -e . --no-build-isolation
Note that you may run into trouble if a second CUDA version is installed in your Python environment.
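One way to spot this: pip-installed CUDA components ship as nvidia-*-cu12 / nvidia-*-cu13 packages, so if both suffixes appear in your environment (or they coexist with a [cudaXX-local] setup), that is a likely source of the trouble above. The diagnostic below is a sketch, not an official tool.

```python
from importlib.metadata import distributions

def cuda_pip_packages():
    """List installed pip packages that bundle CUDA components (nvidia-*-cu12/cu13)."""
    names = sorted({d.metadata["Name"] for d in distributions()
                    if d.metadata["Name"] and d.metadata["Name"].startswith("nvidia-")})
    # Collect which CUDA-version suffixes are present.
    suffixes = {n.rsplit("-", 1)[-1] for n in names
                if n.rsplit("-", 1)[-1] in ("cu12", "cu13")}
    return names, suffixes

names, suffixes = cuda_pip_packages()
if len(suffixes) > 1:
    print("Both CUDA 12 and CUDA 13 pip packages are installed:", names)
```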
Speeding up build-time
As mentioned earlier, the primary way to speed up the build is to set the CUDAARCHS environment variable. For example,
CUDAARCHS=87 pip install -e . --no-build-isolation
A more advanced way of reducing the build time is to reduce the number of template variants that are
instantiated, by modifying the code generation script src/_generate_ffi.py (and executing it again).
This is explained in more detail in CUDA kernels and automatic FFI generation.
Hello World
You can verify that the installation was successful by running
python checks/hello_world.py
This should give you output like:
rnn: [[0. 0.00327169 0.00362817 0.00418469 0.00620413 0.00629311
0.00657683 0.0065819 ]
[0. 0.0018539 0.00218362 0.00325193 0.00418783 0.00457177
0.00464585 0.00483315]
[0. 0.00392929 0.00410999 0.00522736 0.00623543 0.00679859
0.006818 0.00706907]]
Should be:
[[0. 0.00327169 0.00362817 0.00418469 0.00620413 0.00629311
0.00657683 0.0065819 ]
[0. 0.0018539 0.00218362 0.00325193 0.00418783 0.00457177
0.00464585 0.00483315]
[0. 0.00392929 0.00410999 0.00522736 0.00623543 0.00679859
0.006818 0.00706907]]
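checks/hello_world.py presumably does this comparison for you, but if you want to check agreement by hand, a few lines of plain Python suffice. The values below are the first entries of the expected first row above; the allclose helper is an illustrative stand-in for numpy.allclose.

```python
import math

def allclose(a, b, rel_tol=1e-5, abs_tol=1e-8):
    """Elementwise comparison of two equally shaped nested lists of floats."""
    if isinstance(a, list):
        return len(a) == len(b) and all(
            allclose(x, y, rel_tol, abs_tol) for x, y in zip(a, b))
    return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)

expected = [0.0, 0.00327169, 0.00362817, 0.00418469]  # from the expected output above
actual = [0.0, 0.00327169, 0.00362817, 0.00418469]    # substitute your rnn output here
print(allclose(actual, expected))  # True
```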
If you see a warning like this at the beginning
E0406 01:02:01.239070 1581469 cuda_executor.cc:1743] Could not get kernel mode driver version: [INVALID_ARGUMENT: Version does not match the format X.Y.Z]
don't worry: it comes from recent jax versions and can be safely ignored, or silenced by setting the environment variable
TF_CPP_MIN_LOG_LEVEL=3.