Linear
场景说明
对输入tensor进行线性变换,目前kdnn支持torch.int8、torch.float16和torch.float32数据类型,其他数据类型会走开源分支。
示例代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | import torch import torch.nn as nn #使能KDNN torch._C._set_kdnn_enabled(True) # linear示例, # 输入数据: (batch, in_features),默认数据类型torch.float32 input_tensor = torch.randn(128, 20) # 构造linear层 linear = nn.Linear(20, 30, bias=True) # input_feature为20, output_feature为30 #向前计算 ouput_tensor = linear(input_tensor) # output_tensor shape 为 [128,30] # 打印输出的形状和数值 print("linear输出形状", ouput_tensor.shape) print(ouput_tensor) |
父主题: 使用示例