permute函数用于张量维度转换。
1 | torchvision.ops.permute(dims:List[int]).contiguous() |
参数名 |
描述 |
取值范围 |
输入/输出 |
---|---|---|---|
dims |
序列,指定转置的维度顺序。 |
[2, 0, 1] |
输入 |
错误码 |
描述 |
---|---|
KP_PT_STS_NULL_PTR_ERR |
指针内存分配错误。 |
KP_PT_PERMUTE_MAP_WRN |
不支持的转置dims。 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | import numpy as np import torch from torchvision import transforms, ops # 创建一个5x5的简单图像 src = np.array([[1, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 0, 1, 0], [0, 1, 1, 1, 0], [1, 0, 0, 0, 0]], dtype=np.float32) # 将NumPy数组转换为Tensor src_tensor = torch.from_numpy(src) # 定义维度重排的顺序 dims = [1, 0] # 使用torchvision.ops.permute进行图像维度重排 permute_op = ops.Permute(dims) permuted_tensor = permute_op(src_tensor) print(permuted_tensor) |
运行结果:
1 2 3 4 5 | tensor([[1., 0., 0., 0., 1.], [0., 1., 1., 1., 0.], [0., 1., 0., 1., 0.], [0., 1., 1., 1., 0.], [0., 0., 0., 0., 0.]]) |