鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

Linear

场景说明

对输入tensor进行线性变换,目前kdnn支持torch.int8、torch.float16和torch.float32数据类型,其他数据类型会走开源分支。

示例代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
import torch
import torch.nn as nn
#使能KDNN
torch._C._set_kdnn_enabled(True)

# linear示例,
# 输入数据: (batch, in_features),默认数据类型torch.float32
input_tensor = torch.randn(128, 20)

# 构造linear层
linear = nn.Linear(20, 30, bias=True) # input_feature为20 output_feature为30

#向前计算
ouput_tensor = linear(input_tensor) # output_tensor shape  [128,30]

# 打印输出的形状和数值
print("linear输出形状", ouput_tensor.shape)
print(ouput_tensor)