安装jaxlib
操作步骤
- 使用PuTTY工具,以root用户登录服务器。
- 执行以下命令安装jax。
pip install jax
- 执行以下命令下载jaxlib源码。
wget https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.70.tar.gz
- 执行以下命令安装jaxlib依赖。
pip install numpy scipy cython six
- 执行以下命令设置编译器。
export CC=gcc CXX=g++ FC=gfortran
- 执行以下命令编译CUDA版本。
python build/build.py --enable_cuda pip install -e build # installs jaxlib (includes XLA)
- 执行以下命令安装生成的jaxlib.whl文件。
pip install ./disk/jaxlib-0.17.0.whl
父主题: 配置编译环境