获取DLRM模型源码
DLRM模型训练代码采用的是NodLabs的开源仓中的源码,获取源码后需要合入DLRM模型训练代码补丁文件以适配鲲鹏平台。
- 进入DLRM源码规划路径“/path/to/dlrm”。
1
cd /path/to/dlrm
- 配置git网络代理。
1 2 3
git config --global http.sslVerify false git config --global https.sslverify false git config --global http.proxy "http://用户名:密码@代理IP地址:代理端口"
- 下载tensorflow-dlrm源码。
1
git clone https://github.com/NodLabs/tensorflow-dlrm.git
- 进入源码目录。
1
cd tensorflow-dlrm
- 创建并编写DLRM训练代码补丁“train.patch”文件。
- 创建“train.patch”文件。
1
vi train.patch
- 按“i”进入编辑模式,编辑“train.patch”文件,输入以下内容。
diff --git a/dlrm_criteo_gpu.py b/dlrm_criteo_gpu.py index c2dfeac..0a71668 100644 --- a/dlrm_criteo_gpu.py +++ b/dlrm_criteo_gpu.py @@ -5,7 +5,7 @@ from tqdm import tqdm import tensorflow as tf import dataloader -raw_data = dataloader.load_criteo('../dataset/') +raw_data = dataloader.load_criteo('../../dataset/') dim_embed = 4 bottom_mlp_size = [8, 4] top_mlp_size = [128, 64, 1] @@ -71,4 +71,4 @@ for train_iter, batch_data in enumerate(train_dataset): average_loss.reset_states() auc.reset_states() -dlrm_model.save('DLRMModel_tf2_2') +dlrm_model.save_weights('mymodel')
- 按“Esc”键,输入:wq!,按“Enter”保存并退出编辑。
- 创建“train.patch”文件。
- 合入训练代码补丁。
1
git apply train.patch
若命令执行后回显无报错信息提示,证明补丁合入成功。
- 验证补丁正确性。
1
git diff --stat
若回显信息显示文件的修改行数与上图一致,证明补丁合入正确。
父主题: DLRM模型训练