鲲鹏社区首页
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助

获取DLRM模型源码

DLRM模型训练代码采用的是NodLabs的开源仓中的源码,获取源码后需要合入DLRM模型训练代码补丁文件以适配鲲鹏平台。

  1. 进入DLRM源码规划路径“/path/to/dlrm”。
    1
    cd /path/to/dlrm
    
  2. 配置git网络代理。
    1
    2
    3
    git config --global http.sslVerify false
    git config --global https.sslverify false
    git config --global http.proxy "http://用户名:密码@代理IP地址:代理端口"
    
  3. 下载tensorflow-dlrm源码。
    1
    git clone https://github.com/NodLabs/tensorflow-dlrm.git
    
  4. 进入源码目录。
    1
    cd tensorflow-dlrm
    
  5. 创建并编写DLRM训练代码补丁“train.patch”文件。
    1. 创建“train.patch”文件。
      1
      vi train.patch
      
    2. 按“i”进入编辑模式,编辑“train.patch”文件,输入以下内容。
      diff --git a/dlrm_criteo_gpu.py b/dlrm_criteo_gpu.py
      index c2dfeac..0a71668 100644
      --- a/dlrm_criteo_gpu.py
      +++ b/dlrm_criteo_gpu.py
      @@ -5,7 +5,7 @@ from tqdm import tqdm
       import tensorflow as tf
       import dataloader
      
      -raw_data = dataloader.load_criteo('../dataset/')
      +raw_data = dataloader.load_criteo('../../dataset/')
       dim_embed = 4
       bottom_mlp_size = [8, 4]
       top_mlp_size = [128, 64, 1]
      @@ -71,4 +71,4 @@ for train_iter, batch_data in enumerate(train_dataset):
               average_loss.reset_states()
               auc.reset_states()
      
      -dlrm_model.save('DLRMModel_tf2_2')
      +dlrm_model.save_weights('mymodel')
      
    3. 按“Esc”键,输入:wq!,按“Enter”保存并退出编辑。
  6. 合入训练代码补丁。
    1
    git apply train.patch
    

    若命令执行后回显无报错信息提示,证明补丁合入成功。

  7. 验证补丁正确性。
    1
    git diff --stat
    

    若回显信息显示文件的修改行数与上图一致,证明补丁合入正确。