鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

GroupNorm

场景说明

使用均方差对tensor做归一化,目前kdnn支持torch.float16和torch.float32数据类型,其他数据类型会走开源分支。

示例代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
import torch
import torch.nn as nn

#使能KDNN
torch._C._set_kdnn_enabled(True)

# 输入数据: (batch_size, channels, height, width)
input = torch.randn(2, 6, 32, 32)  # 2张图片6个通道32x32分辨率 默认为fp32数据类型

# 定义 GroupNorm (分3组,每组2个通道)
group_norm = nn.GroupNorm(num_groups=3, num_channels=6)  # 6个通道分成3组 默认为fp32数据类型

# 前向计算
output = group_norm(input)

# 验证: 每组内归一化
print("GroupNorm 输出形状:", output.shape)  # 保持原形状