鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

RMSNorm

场景说明

使用均方差对tensor做归一化,目前kdnn支持torch.float16类型,其他数据类型会走开源分支。

示例代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
import torch
import torch.nn as nn

# 启用kdnn
torch._C._set_kdnn_enabled(True)

# 输入数据: (batch_size, seq_len, hidden_dim)
input = torch.randn(2, 5, 10)

# 构造RMSNorm层
# 对最后一维归一化指定eps为1e-5并指定数据类型为fp16
rms_norm = nn.RMSNorm(10, eps = 1e-5, dtype = torch.float16)  
output = rms_norm(input)

# 验证: 输出的均方根接近缩放值
print("RMSNorm 输出形状:", output.shape)  # 同输入
print("RMSNorm 最后一维RMS:", torch.sqrt(torch.mean(output.pow(2), dim=-1)))  # 接近scale值