Installing jaxlib

Procedure

  1. Use PuTTY to log in to the server as the root user.
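  2. Install jax. Use a jax release that supports jaxlib 0.1.70; the latest jax release requires a much newer jaxlib.
    pip install "jax==<version>"  # replace <version> with a jax release compatible with jaxlib 0.1.70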
  3. Download the jaxlib source code, extract it, and go to the source directory.
    wget https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.70.tar.gz
    tar -xzf jaxlib-v0.1.70.tar.gz
    cd jax-jaxlib-v0.1.70
  4. Install the jaxlib dependencies.
    pip install numpy scipy cython six
  5. Configure the compiler.
    export CC=gcc CXX=g++ FC=gfortran
  6. Build jaxlib with CUDA support. The build writes the jaxlib wheel (which includes XLA) to the dist directory.
    python build/build.py --enable_cuda
  7. Install the jaxlib wheel.
    pip install ./dist/jaxlib-0.1.70-*.whl  # the file name also contains Python and platform tags
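Because the wheel file name also encodes the Python and platform tags, it can help to resolve the exact file before installing it. The following is a minimal sketch of such a helper, assuming a POSIX-compatible shell such as bash; `find_jaxlib_wheel` is a hypothetical name, not part of the jax repository.

```shell
# Hypothetical helper (not part of the jax repository): print the path of the
# single jaxlib wheel for VERSION in DIR, or fail if zero or several match.
# Usage: find_jaxlib_wheel DIR VERSION
find_jaxlib_wheel() {
  local dir="$1" version="$2"
  # Expand the glob into this function's positional parameters.
  set -- "$dir"/jaxlib-"$version"-*.whl
  # If nothing matches, the shell leaves the pattern unexpanded, so also
  # check that the single result actually exists.
  if [ "$#" -ne 1 ] || [ ! -e "$1" ]; then
    echo "expected exactly one jaxlib-$version wheel in $dir" >&2
    return 1
  fi
  printf '%s\n' "$1"
}
```

For example, run `wheel=$(find_jaxlib_wheel dist 0.1.70) && pip install "$wheel"` from the jax source directory. The helper fails when no wheel or more than one wheel matches, so a stale wheel left over from an earlier build is not installed by accident.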
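To confirm the installation, a quick check might compare the jaxlib version that Python actually imports with the version that was built, and then list the devices jax can see. This is a sketch assuming `python` is the same interpreter used in the steps above; `verify_jaxlib` is a hypothetical name, while `jaxlib.version.__version__` and `jax.devices()` come from the installed packages.

```shell
# Hypothetical post-install check (not part of the jax repository).
# Usage: verify_jaxlib EXPECTED_VERSION
verify_jaxlib() {
  local expected="$1" installed
  # Ask the interpreter which jaxlib version it actually imports.
  installed=$(python -c 'import jaxlib.version; print(jaxlib.version.__version__)') || return 1
  if [ "$installed" != "$expected" ]; then
    echo "jaxlib $installed is installed, expected $expected" >&2
    return 1
  fi
  echo "jaxlib $installed"
  # jax.devices() lists the devices of the default backend; with a working
  # CUDA build and a visible GPU, these are GPU devices.
  python -c 'import jax; print(jax.devices())'
}
```

For example, run `verify_jaxlib 0.1.70`. If the device list contains only CPU devices, jax did not pick up the CUDA build or cannot see the GPU.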