RMSNorm
场景说明
使用均方差对tensor做归一化,目前kdnn支持torch.float16类型,其他数据类型会走开源分支。
示例代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | import torch import torch.nn as nn # 启用kdnn torch._C._set_kdnn_enabled(True) # 输入数据: (batch_size, seq_len, hidden_dim) input = torch.randn(2, 5, 10) # 构造RMSNorm层 # 对最后一维归一化指定eps为1e-5并指定数据类型为fp16 rms_norm = nn.RMSNorm(10, eps = 1e-5, dtype = torch.float16) output = rms_norm(input) # 验证: 输出的均方根接近缩放值 print("RMSNorm 输出形状:", output.shape) # 同输入 print("RMSNorm 最后一维RMS:", torch.sqrt(torch.mean(output.pow(2), dim=-1))) # 接近scale值 |
父主题: 使用示例