使用说明
接口定义
初始化RopeLayerFWD。构造时需要传入最大支持的文本上下文长度、head的特征维度、计算角度theta的基数。
参数名称 |
数据类型 |
描述 |
取值范围 |
|---|---|---|---|
max_position |
int |
最大的上下文长度。 |
[1, INT_MAX] |
dim |
int |
head的特征维度。 |
[1, INT_MAX] |
base |
float |
计算theta的底数。 |
计算theta的底数 |
scaling_factor |
float |
theta的放缩因子。 |
theta的放缩因子 |
执行算子运算。
Run(void *query, void *key, const int qStride, const int kStride,const int batchSize, const int seqLen, const int qHeads, const int kHeads,const int *positionIds)->void
参数名称 |
数据类型 |
描述 |
取值范围 |
|---|---|---|---|
query |
void* |
需要应用旋转位置编码的数据指针。 |
- |
key |
void* |
需要应用旋转位置编码的数据指针。 |
- |
qStride |
int |
query的相邻距离。 |
[1, INT_MAX] |
kStride |
int |
Key的相邻距离。 |
[1, INT_MAX] |
batchSize |
int |
批大小。 |
[1, INT_MAX] |
seqLen |
int |
语句长度。 |
[1, INT_MAX] |
qHeads |
int |
query头个数。 |
[1, INT_MAX] |
kHeads |
int |
key头个数。 |
[1, INT_MAX] |
positionlds |
int* |
token的位置编号列表。 |
- |
kind |
KDNN::RopeKind |
rope算法实现类型 |
KDNN::RopeKind::OpenSora KDNN::RopeKind::DeepSeek |
验证RopeLayerFWD的输入参数,并在算子构造过程中自动触发执行。
ValidateInput(const int maxPosition, const int dim, const float base, const float scalingFactor)->KDNN::Status
参数名称 |
数据类型 |
描述 |
取值范围 |
|---|---|---|---|
max_position |
int |
最大的上下文长度。 |
[1, INT_MAX] |
dim |
int |
head的特征维度。 |
[1, INT_MAX] |
base |
float |
计算theta的底数。 |
计算theta的底数 |
scaling_factor |
float |
theta的放缩因子。 |
theta的放缩因子 |
支持的数据类型
query |
key |
|---|---|
fp16 |
fp16 |
支持的内存排布顺序为(B, S, H, D)。
- B:表示bath_size批大小。
- S:表示seq_len语句长度。
- H:表示head_num头数量。
- D:表示head_dim头特征维度。
使用示例
在最大的上下文长度max_position为10000,head的特征维度dim为128时,对query和key进行旋转位置编码。
1 2 3 4 5 6 7 8 9 10 11 12 13 | int max_position = 10000, dim = 128; int batchSize = 2, seqLen = 128, qHeads = 28, kHeads = 28; //对应内存排布顺序为(B, S, H, D) // 构造算子 KDNN::RopeLayerFWD layer(max_position, dim); int qStride = qHeads * dim, kStride = kHeads * dim; int qSize = batchSize * seqLen * qHeads * dim; int kSize = batchSize * seqLen * kHeads * dim; // 入参出参内存申请 __fp16 *query = (__fp16 *)malloc(qSize * sizeof(__fp16)); __fp16 *key = (__fp16 *)malloc(kSize * sizeof(__fp16)); int *positionIds = (int *)malloc(seqLen * sizeof(int)); // 执行算子 layer.Run(query, key, qStride, kStride, batchSize, seqLen, qHeads, kHeads, positionIds); |