# Install JAX and build a CUDA-enabled jaxlib (tag jaxlib-v0.1.70) from source.
#
# Run from a scratch directory: the source archive is downloaded next to this
# script's working directory and unpacked into ./jax-src.

pip install jax

# Fetch the jaxlib source and unpack it; build/build.py lives inside this
# archive, so it must be extracted before the build step can run.
# --strip-components=1 drops the archive's top-level directory so the tree
# always lands in ./jax-src regardless of how that directory is named.
wget https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.70.tar.gz &&
  mkdir -p jax-src &&
  tar -xzf jaxlib-v0.1.70.tar.gz --strip-components=1 -C jax-src

# Python-side prerequisites for the build.
pip install numpy scipy cython six

# Compilers picked up by the native build.
export CC=gcc CXX=g++ FC=gfortran

# Build jaxlib (includes XLA) with CUDA support, then install the wheel that
# the build produces. The build and the install are two separate commands;
# the install only runs if the build succeeded. The JAX build instructions
# state that build.py writes the wheel under dist/, and the wheel's file name
# carries python/abi/platform tags, hence the glob (left unquoted on purpose).
# The subshell keeps the caller's working directory unchanged.
(
  cd jax-src || exit 1
  python build/build.py --enable_cuda &&
    pip install dist/jaxlib-*.whl
)