
Installation and Verification

  1. Install the GPU build of TensorFlow 2.21.0.
    python3 -m pip install "tensorflow[and-cuda]==2.21.0"
    
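  2. Add the CUDA runtime libraries installed by the NVIDIA pip packages to the library search path. The paths below assume Python 3.11 with its site-packages directory under /usr/local/lib/python3.11; adjust them to match your environment.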
    export LD_LIBRARY_PATH=/usr/lib64:\
    /usr/local/lib/python3.11/site-packages/nvidia/cublas/lib:\
    /usr/local/lib/python3.11/site-packages/nvidia/cuda_cupti/lib:\
    /usr/local/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib:\
    /usr/local/lib/python3.11/site-packages/nvidia/cuda_runtime/lib:\
    /usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib:\
    /usr/local/lib/python3.11/site-packages/nvidia/cufft/lib:\
    /usr/local/lib/python3.11/site-packages/nvidia/curand/lib:\
    /usr/local/lib/python3.11/site-packages/nvidia/cusolver/lib:\
    /usr/local/lib/python3.11/site-packages/nvidia/cusparse/lib:\
    /usr/local/lib/python3.11/site-packages/nvidia/nccl/lib:\
    /usr/local/lib/python3.11/site-packages/nvidia/nvjitlink/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
    
  3. Run the following command to verify the TensorFlow version, GPU device visibility, and tensor computation on the GPU.
    python3 - <<'PY'
    import tensorflow as tf
    
    print("tensorflow_version=" + tf.__version__)
    print("physical_gpus=" + repr(tf.config.list_physical_devices("GPU")))
    print("built_with_cuda=" + str(tf.test.is_built_with_cuda()))
    assert tf.__version__ == "2.21.0"
    assert tf.config.list_physical_devices("GPU")
    assert tf.test.is_built_with_cuda()
    with tf.device("/GPU:0"):
        a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
        b = tf.constant([[1.0, 2.0], [3.0, 4.0]])
        c = tf.matmul(a, b)

    print("tensorflow_gpu_matmul=" + str(c.numpy().tolist()))
    PY
    

    The expected output is as follows.

    tensorflow_version=2.21.0 
    physical_gpus=[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')] 
    built_with_cuda=True 
    tensorflow_gpu_matmul=[[7.0, 10.0], [15.0, 22.0]]
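The directory list in step 2 is tied to one interpreter layout (Python 3.11 under /usr/local). As an alternative sketch, assuming the NVIDIA wheels place their libraries under site-packages/nvidia/<component>/lib (the layout used in step 2), the stdlib-only script below discovers those directories for whichever interpreter runs it and prints an equivalent export line. The helper names nvidia_lib_dirs and build_ld_library_path are hypothetical; they are not part of TensorFlow or the NVIDIA packages.

```python
# Hypothetical helper: derive LD_LIBRARY_PATH from the NVIDIA pip packages
# installed for the current interpreter instead of hard-coding python3.11 paths.
# Assumption: the wheels use the site-packages/nvidia/<component>/lib layout.
import glob
import os
import shlex
import sysconfig


def nvidia_lib_dirs(roots=None):
    """Return the sorted nvidia/<component>/lib directories found under roots."""
    if roots is None:
        paths = sysconfig.get_paths()
        # purelib and platlib usually point at the same site-packages directory.
        roots = {paths["purelib"], paths["platlib"]}
    found = set()
    for root in roots:
        for candidate in glob.glob(os.path.join(root, "nvidia", "*", "lib")):
            if os.path.isdir(candidate):
                found.add(candidate)
    return sorted(found)


def build_ld_library_path(existing="", roots=None):
    """Join /usr/lib64, the NVIDIA lib dirs, and any existing value with ':'."""
    parts = ["/usr/lib64"] + nvidia_lib_dirs(roots)
    if existing:
        # Append only when non-empty: an empty entry means the current directory.
        parts.append(existing)
    return ":".join(parts)


if __name__ == "__main__":
    value = build_ld_library_path(os.environ.get("LD_LIBRARY_PATH", ""))
    # Quote the value so the printed line is safe to eval in a POSIX shell.
    print("export LD_LIBRARY_PATH=" + shlex.quote(value))
```

Saved under a name of your choosing (for example the hypothetical nvidia_ld_path.py), it could be applied in the current shell with eval "$(python3 nvidia_ld_path.py)" before running step 3. The directories come out in sorted order and may include NVIDIA components beyond those listed in step 2 if more wheels are installed.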
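The last expected-output line is simply the 2×2 matrix from step 3 multiplied by itself. A plain-Python reference product makes it easy to see where [[7.0, 10.0], [15.0, 22.0]] comes from, for example 1.0·1.0 + 2.0·3.0 = 7.0 for the top-left entry. The matmul function here is a hypothetical helper that uses neither TensorFlow nor the GPU.

```python
# Hypothetical reference implementation used only to reproduce the expected
# tensorflow_gpu_matmul line; it does not use TensorFlow or the GPU.
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    inner = len(b)
    if any(len(row) != inner for row in a):
        raise ValueError("inner dimensions do not match")
    cols = len(b[0])
    # Entry (i, j) is the dot product of row i of a and column j of b.
    return [
        [sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for row in a
    ]


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[1.0, 2.0], [3.0, 4.0]]
expected = matmul(a, b)
# Same formatting as the last line printed by the verification script.
print("tensorflow_gpu_matmul=" + str(expected))  # tensorflow_gpu_matmul=[[7.0, 10.0], [15.0, 22.0]]
```

The same expected list could be compared against c.numpy().tolist() in the verification script if an explicit assertion on the GPU result is wanted.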