鲲鹏社区首页
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助

安装jaxlib

操作步骤

  1. 使用PuTTY工具,以root用户登录服务器。
  2. 执行以下命令安装jax。
    pip install jax
  3. 执行以下命令下载jaxlib源码。
    wget https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.70.tar.gz
  4. 执行以下命令安装jaxlib依赖。
    pip install numpy scipy cython six
  5. 执行以下命令设置编译器。
    export CC=gcc CXX=g++ FC=gfortran
  6. 执行以下命令编译CUDA版本。
    python build/build.py --enable_cuda
    pip install -e build  # installs jaxlib (includes XLA)
  7. 执行以下命令安装生成的jaxlib.whl文件。
    pip install ./disk/jaxlib-0.17.0.whl