Installing jaxlib
Procedure
- Use PuTTY to log in to the server as the root user.
- Install jax.
pip install jax
- Download the jaxlib source code.
wget https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.70.tar.gz
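The later build step runs build/build.py, which lives inside the source tree, so the downloaded archive has to be unpacked first. The following is a minimal sketch of that step, not part of the original procedure. It assumes wget saved the archive as jaxlib-v0.1.70.tar.gz (the last component of the URL). The helper name unpack_source is hypothetical. The top-level directory name is read from the archive itself rather than guessed.

```shell
# Hypothetical helper: unpack a source tarball and print its top-level directory.
unpack_source() {
    tarball="$1"
    # Read the top-level directory name from the archive instead of guessing it.
    top_dir=$(tar -tzf "$tarball" | head -n 1 | cut -d/ -f1)
    [ -n "$top_dir" ] || return 1
    tar -xzf "$tarball" || return 1
    printf '%s\n' "$top_dir"
}

# Assumed file name: what wget saves for the URL used in the download step.
TARBALL=jaxlib-v0.1.70.tar.gz
if [ -f "$TARBALL" ]; then
    # Enter the extracted tree so that build/build.py resolves in later steps.
    SRC_DIR=$(unpack_source "$TARBALL") && cd "$SRC_DIR"
fi
```

If the archive was saved under a different name, set TARBALL accordingly before running the snippet.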
- Install the jaxlib dependencies.
pip install numpy scipy cython six
- Configure the compiler.
export CC=gcc CXX=g++ FC=gfortran
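Exporting CC, CXX, and FC only records the compiler names; it does not confirm that the compilers are installed. As an optional check, the sketch below reports any of them that cannot be found on PATH. The helper name require_tools is hypothetical, and the fallback values simply mirror the export above.

```shell
# Hypothetical helper: report every tool that is not on PATH.
# Returns 0 if all tools were found, 1 otherwise.
require_tools() {
    missing=0
    for tool in "$@"; do
        if ! command -v "$tool" >/dev/null 2>&1; then
            echo "missing: $tool" >&2
            missing=1
        fi
    done
    return "$missing"
}

# Check the compilers named by CC, CXX, and FC (defaults mirror the export step).
require_tools "${CC:-gcc}" "${CXX:-g++}" "${FC:-gfortran}" \
    || echo "install the missing compilers before building" >&2
```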
- Compile the CUDA version.
python build/build.py --enable_cuda
pip install -e build # installs jaxlib (includes XLA)
- Install the built jaxlib wheel.
pip install ./dist/jaxlib-0.1.70*.whl
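After the wheel is installed, a quick import check can confirm that both packages load. This is a sketch under assumptions rather than part of the original procedure: the helper name can_import is hypothetical, and the interpreter is taken from PYTHON if set, otherwise from whichever of python or python3 is on PATH.

```shell
# Pick an interpreter; the commands in this procedure use `python`.
PYTHON="${PYTHON:-$(command -v python || command -v python3)}"

# Hypothetical helper: succeed only if the given module imports cleanly.
can_import() {
    "$PYTHON" -c "import $1" >/dev/null 2>&1
}

# Report the import status of each package installed by this procedure.
for module in jax jaxlib; do
    if can_import "$module"; then
        echo "$module: OK"
    else
        echo "$module: import failed" >&2
    fi
done
```

An import failure for jaxlib usually means the wheel was installed into a different Python environment than the one being checked.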