鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

Conv

场景说明

Conv2d:常用于边缘检测、特征提取。

Conv3d:常用于时空特征提取。

目前KDNN支持torch.float16和torch.float32数据类型,其他数据类型会走开源分支。

示例代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
import torch
import torch.nn as nn

#使能KDNN
torch._C._set_kdnn_enabled(True)

# Conv2d示例
conv2d = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1)
x = torch.randn(8, 3, 32, 32)  # [batch, channel, H, W] # 默认为fp32类型
y = conv2d(x)  # 输出 [8, 16, 30, 30]
print("Conv2d输出形状", y.shape)
print(y)

# Conv3d示例
conv3d = nn.Conv3d(1, 8, kernel_size=(3,3,3)) 
x = torch.randn(4, 1, 10, 64, 64)  # [batch, channel, depth, H, W] # 默认为fp32类型
y = conv3d(x)  # 输出 [4, 8, 8, 62, 62]

print("Conv3d输出形状", y.shape)
print(y)