使用说明
接口定义
初始化RMSNormalizationLayerFWD。构造时需要传入输入矩阵、缩放矩阵、输出矩阵的tensor信息,其中statsinfo是均值和方差的描述信息。
参数名称 |
数据类型 |
描述 |
取值范围 |
|---|---|---|---|
srcInfo |
KDNN::TensorInfo |
输入矩阵信息。 |
{shape{A, ... , D}, type, layout} |
statsInfo |
KDNN::TensorInfo |
均值和方差的信息。 |
{shape{A, ...}, type, layout} |
scaleShiftInfo |
KDNN::TensorInfo |
缩放矩阵信息。 |
{shape{D}, type, layout} |
dstInfo |
KDNN::TensorInfo |
输出矩阵信息。 |
{shape{A, ... , D}, type, layout} |
flags |
KDNN::NormalizationFlags |
枚举类,用于选择归一化方式。 |
默认NONE,可选以下参数:
|
假设输入数据形状为{A,B,C,D} 则在最后一个维度即D维上进行归一化,放缩和偏移操作也在D维度上进行。
执行算子运算。src和dst分别为输入和输出的指针,scale和shift为放缩和偏移,mean和variance为均值和方差的指针,saveStats表示是否保存计算的均值和方差。eps用于避免除0异常。
参数名称 |
数据类型 |
描述 |
取值范围 |
|---|---|---|---|
src |
void* |
输入指针。 |
- |
dst |
void* |
输出指针。 |
- |
scale |
void* |
放缩指针。 |
- |
shift |
void* |
偏移指针。 |
- |
mean |
float* |
均值指针。 |
- |
variance |
float* |
方差指针。 |
- |
saveStats |
bool |
是否保存计算方差。 |
布尔值 |
eps |
float |
避免除0异常。 |
浮点数 |
验证RMSNormalizationLayerFWD的输入参数,并在算子构造过程中自动触发执行。
参数名称 |
数据类型 |
描述 |
取值范围 |
|---|---|---|---|
srcInfo |
KDNN::TensorInfo |
输入矩阵信息。 |
{shape{A, ... , D}, type, layout} |
statsInfo |
KDNN::TensorInfo |
均值和方差的信息。 |
{shape{A, ...}, type, layout} |
scaleShiftInfo |
KDNN::TensorInfo |
缩放矩阵信息。 |
{shape{D}, type, layout} |
dstInfo |
KDNN::TensorInfo |
输出矩阵信息。 |
{shape{A, ... , D}, type, layout} |
flags |
KDNN::NormalizationFlags |
枚举类,用于选择归一化方式。 |
默认NONE,可选:
|
支持数据类型
- RMSNorm支持fp16数据类型。(TensorInfo对象初始化时需传入Shape、Type、Layout参数,此处列出为Type支持数据类型。)
表4 TensorInfo对象初始化时支持的Type类型 srcInfo
statInfo
scaleShiftInfo
KDNN::Element::Type::F16(fp16)
KDNN::Element::Type::F16(fp16)
KDNN::Element::Type::F16(fp16)
使用示例
使用fp16数据类型均方根归一化,srcInfo数据排布为KDNN::Layout::AB,statInfo数据排布为KDNN::Layout::A,scaleShiftInfo数据排布为KDNN::Layout::A,dstInfo数据排布KDNN::Layout::AB。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | using SizeType = KDNN::SizeType; using Shape = KDNN::Shape; using Type KDNN::Element::TypeT Shape shape(100, 100); //定义张量信息 TensorInfo srcInfo = {shape, Type::F16, KDNN::Layout::AB}; TensorInfo statInfo = {{shape[0]}, Type::F16, KDNN::Layout::A}; TensorInfo scaleInfo = {{shape[1]}, Type::F16, KDNN::Layout::A}; TensorInfo dstInfo = {shape, Type::F16, KDNN::Layout::AB}; KDNN::NormalizationFlags flags = KDNN::NormalizationFlags::NONE; // 构造算子 KDNN::RMSNormalizationLayerFWD rmsLayer1(srcInfo, statInfo, scaleInfo, dstInfo, flags); //初始化矩阵数据 SizeType srcSize = srcInfo.GetTotalTensorSize(); SizeType dstSize = dstInfo.GetTotalTensorSize(); SizeType statSize = statInfo.GetTotalTensorSize(); SizeType innerSize = scaleInfo.GetTotalTensorSize(); __fp16 *src = (__fp16 *)malloc(srcSize * sizeof(__fp16)); __fp16 *dst = (__fp16 *)malloc(dstSize * sizeof(__fp16)); __fp16 *dstRef = (__fp16 *)malloc(dstSize * sizeof(__fp16)); float *variance = (float *)malloc(statSize * sizeof(float)); __fp16 *scale = (__fp16 *)malloc(innerSize * sizeof(__fp16)); float eps = 1e-5; // 执行算子 rmsLayer1.Run(src, dst, scale, variance, true, eps); |