Conv
场景说明
Conv2d:常用于边缘检测、特征提取。
Conv3d:常用于时空特征提取。
目前KDNN支持torch.float16和torch.float32数据类型,其他数据类型会走开源分支。
示例代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | import torch import torch.nn as nn #使能KDNN torch._C._set_kdnn_enabled(True) # Conv2d示例 conv2d = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1) x = torch.randn(8, 3, 32, 32) # [batch, channel, H, W] # 默认为fp32类型 y = conv2d(x) # 输出 [8, 16, 30, 30] print("Conv2d输出形状", y.shape) print(y) # Conv3d示例 conv3d = nn.Conv3d(1, 8, kernel_size=(3,3,3)) x = torch.randn(4, 1, 10, 64, 64) # [batch, channel, depth, H, W] # 默认为fp32类型 y = conv3d(x) # 输出 [4, 8, 8, 62, 62] print("Conv3d输出形状", y.shape) print(y) |
父主题: 使用示例