conv2d_fft_?

FFT算法实现的2D卷积接口,仅支持膨胀系数为1的非膨胀卷积计算。

接口定义

C interface:

void conv2d_fft_fp32(const float *input, const int batch, const int inputChannels, const int inputHeight, const int inputWidth, const float *kernel, const int kernelHeight, const int kernelWidth, const int strideY, const int strideX, const int padHeight, const int padWidth, const float *bias, float *output, const int outputChannels);

void conv2d_fft_fp16(const __fp16 *input, const int batch, const int inputChannels, const int inputHeight, const int inputWidth, const __fp16 *kernel, const int kernelHeight, const int kernelWidth, const int strideY, const int strideX, const int padHeight, const int padWidth, const __fp16 *bias, __fp16 *output, const int outputChannels);

参数

参数名

类型

描述

输入/输出

input

conv2d_fft_fp32中是float类型

conv2d_fft_fp16中是__fp16类型

输入数据

输入

batch

int类型

输入数据的批数

输入

inputChannels

int类型

输入通道数

输入

inputHeight

int类型

输入数据的高度

输入

inputWidth

int类型

输入数据的宽度

输入

kernel

conv2d_fft_fp32中是float类型

conv2d_fft_fp16中是__fp16类型

卷积核

输入

kernelHeight

int类型

卷积核的高度

输入

kernelWidth

int类型

卷积核的宽度

输入

strideY

int类型

卷积核在高度方向移动的步长

输入

strideX

int类型

卷积核在宽度方向移动的步长

输入

padHeight

int类型

在原始input数据高度方向两端分别置零的长度

输入

padWidth

int类型

在原始input数据宽度方向两端分别置零的长度

输入

bias

conv2d_fft_fp32中是float类型

conv2d_fft_fp16中是__fp16类型

偏置值, 值为NULL时表示无偏置

输入

output

conv2d_fft_fp32中是float类型

conv2d_fft_fp16中是__fp16类型

输出结果数据

输出

outputChannels

int类型

输出通道数

输入

依赖

#include "conv.h"

示例

C interface:
    int batch = 1;
    int inputChannels = 1;
    int inputHeight = 6;
    int inputWidth = 6;
    int kernelHeight = 3;
    int kernelWidth = 3;
    int strideY = 1;
    int strideX = 1;
    int padHeight = 0;
    int padWidth = 0;
    int outputChannels = 1;
    float input[36] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
                       7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
                       13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
                       19.0, 20.0, 21.0, 22.0, 23.0, 24.0,
                       25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
                       31.0, 32.0, 33.0, 34.0, 35.0, 36.0};
    float kernel[9] = {1.0, 2.0, 3.0,
                       4.0, 5.0, 6.0,
                       7.0, 8.0, 9.0};

    float *bias = NULL;
    int outputHeight = (inputHeight + 2 * padHeight - (kernelHeight - 1) - 1) / strideY + 1;
    int outputWidth = (inputWidth + 2 * padWidth - (kernelWidth - 1) - 1) / strideX + 1;
    /*
     *   outputHeight x outputWidth = 4 x 4
     */

    float output[16] = {0.0};
    conv2d_fft_fp32(input, batch, inputChannels, inputHeight, inputWidth, kernel, kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, bias, output, outputChannels);
    /*
     * output = [474.0 519.0 564.0 609.0,
     *           744.0 789.0 834.0 879.0, 
     *           1014.0 1059.0 1104.0 1149.0, 
     *           1284.0 1329.0 1374.0 1419.0]
     */