GroupNorm
场景说明
使用均方差对tensor做归一化,目前kdnn支持torch.float16和torch.float32数据类型,其他数据类型会走开源分支。
示例代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | import torch import torch.nn as nn #使能KDNN torch._C._set_kdnn_enabled(True) # 输入数据: (batch_size, channels, height, width) input = torch.randn(2, 6, 32, 32) # 2张图片,6个通道,32x32分辨率 默认为fp32数据类型 # 定义 GroupNorm (分3组,每组2个通道) group_norm = nn.GroupNorm(num_groups=3, num_channels=6) # 6个通道分成3组 默认为fp32数据类型 # 前向计算 output = group_norm(input) # 验证: 每组内归一化 print("GroupNorm 输出形状:", output.shape) # 保持原形状 |
父主题: 使用示例