安装jaxlib

操作步骤

  1. 使用PuTTY工具,以root用户登录服务器。
  2. 执行以下命令安装jax。

    pip install jax

  3. 执行以下命令下载jaxlib源码。

    wget https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.70.tar.gz

  4. 执行以下命令安装jaxlib依赖。

    pip install numpy scipy cython six

  5. 执行以下命令设置编译器。

    export CC=gcc CXX=g++ FC=gfortran

  6. 执行以下命令编译CUDA版本。

    python build/build.py --enable_cuda
    pip install -e build  # installs jaxlib (includes XLA)

  7. 执行以下命令安装生成的jaxlib.whl文件。

    pip install ./disk/jaxlib-0.17.0.whl