鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

RMSNorm

接口功能

LayerNorm的轻量变体,仅使用均方根值归一化(无均值中心化)。

函数原型

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor):
        # 计算均方根值
        rms = x.norm(2, dim=-1, keepdim=True) * (x.size(-1) ** -0.5)
        return x / (rms + self.eps) * self.weight

参数说明

表1 参数说明

参数

类型

必填

说明

dim

int

输入特征维度。

eps

float

数值稳定性常数(默认1e-6)。

weight

tensor

可学习的缩放因子(自动创建)。