鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

使用说明

接口定义

初始化RMSNormalizationLayerFWD。构造时需要传入输入矩阵、缩放矩阵、输出矩阵的tensor信息,其中statsinfo是均值和方差的描述信息。

RMSNormalizationLayerFWD(const TensorInfo &srcInfo, const TensorInfo &statsInfo, const TensorInfo &scaleInfo,const TensorInfo &dstInfo, NormalizationFlags flags)->void
表1 RMSNormalizationLayerFWD函数输入参数

参数名称

数据类型

描述

取值范围

srcInfo

KDNN::TensorInfo

输入矩阵信息。

{shape{A, ... , D}, type, layout}

statsInfo

KDNN::TensorInfo

均值和方差的信息。

{shape{A, ...}, type, layout}

scaleShiftInfo

KDNN::TensorInfo

缩放矩阵信息。

{shape{D}, type, layout}

dstInfo

KDNN::TensorInfo

输出矩阵信息。

{shape{A, ... , D}, type, layout}

flags

KDNN::NormalizationFlags

枚举类,用于选择归一化方式。

默认NONE,可选以下参数:

  • USE_GLOBAL_STATS使用传入的mean和var来计算。
  • USE_SCALE对归一化结果再放缩。
  • USE_SHIFT对归一化结果再偏移。

假设输入数据形状为{A,B,C,D} 则在最后一个维度即D维上进行归一化,放缩和偏移操作也在D维度上进行。

执行算子运算。src和dst分别为输入和输出的指针,scale和shift为放缩和偏移,mean和variance为均值和方差的指针,saveStats表示是否保存计算的均值和方差。eps用于避免除0异常。

Run(const void *src, void *dst, const void *scale, float *variance,bool saveStats, const float eps) ->void
表2 Run函数输入参数

参数名称

数据类型

描述

取值范围

src

void*

输入指针。

-

dst

void*

输出指针。

-

scale

void*

放缩指针。

-

shift

void*

偏移指针。

-

mean

float*

均值指针。

-

variance

float*

方差指针。

-

saveStats

bool

是否保存计算方差。

布尔值

eps

float

避免除0异常。

浮点数

验证RMSNormalizationLayerFWD的输入参数,并在算子构造过程中自动触发执行。

ValidateInput(const TensorInfo &srcInfo, const TensorInfo &statsInfo,const TensorInfo &scaleInfo, const TensorInfo &dstInfo,NormalizationFlags flags) ->KDNN::Status
表3 ValidateInput输入参数

参数名称

数据类型

描述

取值范围

srcInfo

KDNN::TensorInfo

输入矩阵信息。

{shape{A, ... , D}, type, layout}

statsInfo

KDNN::TensorInfo

均值和方差的信息。

{shape{A, ...}, type, layout}

scaleShiftInfo

KDNN::TensorInfo

缩放矩阵信息。

{shape{D}, type, layout}

dstInfo

KDNN::TensorInfo

输出矩阵信息。

{shape{A, ... , D}, type, layout}

flags

KDNN::NormalizationFlags

枚举类,用于选择归一化方式。

默认NONE,可选:

  • USE_GLOBAL_STATS使用传入的mean和var来计算。
  • USE_SCALE对归一化结果再放缩。
  • USE_SHIFT对归一化结果再偏移。

支持数据类型

  • RMSNorm支持fp16数据类型。(TensorInfo对象初始化时需传入Shape、Type、Layout参数,此处列出为Type支持数据类型。)
    表4 TensorInfo对象初始化时支持的Type类型

    srcInfo

    statInfo

    scaleShiftInfo

    KDNN::Element::Type::F16(fp16)

    KDNN::Element::Type::F16(fp16)

    KDNN::Element::Type::F16(fp16)

  • 最高支持5Dtensor,支持顺序数据排布:a、ab、abc、abcd、abcde。

    对应KDNN::Layout::A、KDNN::Layout::AB、KDNN::Layout::ABC、KDNN::Layout::ABCD、KDNN::Layout::ABCDE。

    表5 TensorInfo对象初始化时支持的Layout类型

    dimension

    srcInfo数据排布

    dstInfo数据排布

    2D

    ab

    ab

    3D

    abc

    abc

    4D

    abcd

    abcd

    5D

    abcde

    abcde

使用示例

使用fp16数据类型均方根归一化,srcInfo数据排布为KDNN::Layout::AB,statInfo数据排布为KDNN::Layout::A,scaleShiftInfo数据排布为KDNN::Layout::A,dstInfo数据排布KDNN::Layout::AB。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
using SizeType = KDNN::SizeType;
using Shape = KDNN::Shape;
using Type KDNN::Element::TypeT
Shape shape(100, 100);
//定义张量信息
TensorInfo srcInfo = {shape, Type::F16, KDNN::Layout::AB};
TensorInfo statInfo = {{shape[0]}, Type::F16, KDNN::Layout::A};
TensorInfo scaleInfo = {{shape[1]}, Type::F16, KDNN::Layout::A};
TensorInfo dstInfo = {shape, Type::F16, KDNN::Layout::AB};
KDNN::NormalizationFlags flags = KDNN::NormalizationFlags::NONE;
// 构造算子
KDNN::RMSNormalizationLayerFWD rmsLayer1(srcInfo, statInfo, scaleInfo, dstInfo, flags);
//初始化矩阵数据
SizeType srcSize = srcInfo.GetTotalTensorSize();
SizeType dstSize = dstInfo.GetTotalTensorSize();
SizeType statSize = statInfo.GetTotalTensorSize();
SizeType innerSize = scaleInfo.GetTotalTensorSize();
__fp16 *src = (__fp16 *)malloc(srcSize * sizeof(__fp16));
__fp16 *dst = (__fp16 *)malloc(dstSize * sizeof(__fp16));
__fp16 *dstRef = (__fp16 *)malloc(dstSize * sizeof(__fp16));
float *variance = (float *)malloc(statSize * sizeof(float));
__fp16 *scale = (__fp16 *)malloc(innerSize * sizeof(__fp16));
float eps = 1e-5;
// 执行算子
rmsLayer1.Run(src, dst, scale, variance, true, eps);