conv1d_gemm_?
GEMM算法实现的通用1D卷积接口。
接口定义
C interface:
void conv1d_gemm_fp32(const float *input, const int batch, const int inputChannels, const int inputLength, const float* kernel, const int kernelLength, const int stride, const int padLength, const int dilation, const float *bias, float *output, const int outputChannels);
void conv1d_gemm_fp16(const __fp16 *input, const int batch, const int inputChannels, const int inputLength, const __fp16 *kernel, const int kernelLength, const int stride, const int padLength, const int dilation, const __fp16 *bias, __fp16 *output, const int outputChannels);
参数
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
input |
|
输入数据 |
输入 |
batch |
int类型 |
输入数据的批数 |
输入 |
inputChannels |
int类型 |
输入通道数 |
输入 |
inputLength |
int类型 |
输入数据的长度 |
输入 |
kernel |
|
卷积核 |
输入 |
kernelLength |
int类型 |
卷积核长度 |
输入 |
stride |
int类型 |
步长 |
输入 |
padLength |
int类型 |
在原始input数据两端分别置零的长度 |
输入 |
dilation |
int类型 |
膨胀系数 |
输入 |
bias |
|
偏置值, 值为NULL时表示无偏置 |
输入 |
output |
|
输出结果数据 |
输出 |
outputChannels |
int类型 |
输出通道数 |
输入 |
依赖
#include "conv.h"
示例
C interface:
int batch = 1;
int inputChannels = 1;
int inputLength = 10;
int kernelLength = 3;
int stride = 1;
int padLength = 0;
int dilation = 1;
int outputChannels = 1;
float input[10] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0};
float kernel[3] = {1.0, 2.0, 3.0};
float *bias = NULL;
int outputLength = (inputLength + 2 * padLength - dilation * (kernelLength - 1) - 1) / stride + 1;
/*
* outputLength = 8
*/
float output[8] = {0.0};
conv1d_gemm_fp32(input, batch, inputChannels, inputLength, kernel, kernelLength, stride, padLength, dilation, bias, output, outputChannels);
/*
* output = [14.0 20.0 26.0 32.0 38.0 44.0 50.0 56.0]
*/