PyTorch torch.rot90 函数
torch.rot90 是 PyTorch 中用于将张量旋转 90 度的函数。
函数定义
torch.rot90(input, k=1, dims=[0, 1])
参数说明:
input- 输入张量k- 旋转次数,正数表示逆时针,负数表示顺时针(默认 1)dims- 旋转所在的平面维度(默认 [0, 1])
使用示例
实例
import torch
# 二维张量旋转
x = torch.arange(12).reshape(3, 4)
print("原始张量:")
print(x)
result = torch.rot90(x, 1)
print("逆时针旋转 90 度:")
print(result)
result = torch.rot90(x, -1)
print("n顺时针旋转 90 度:")
print(result)
result = torch.rot90(x, 2)
print("n旋转 180 度:")
print(result)
# 沿指定平面旋转(三维张量)
y = torch.arange(24).reshape(2, 3, 4)
print("n三维张量形状:", y.shape)
result = torch.rot90(y, 1, dims=[1, 2])
print("沿 [1, 2] 平面旋转 90 度后形状:", result.shape)
# 二维张量旋转
x = torch.arange(12).reshape(3, 4)
print("原始张量:")
print(x)
result = torch.rot90(x, 1)
print("逆时针旋转 90 度:")
print(result)
result = torch.rot90(x, -1)
print("n顺时针旋转 90 度:")
print(result)
result = torch.rot90(x, 2)
print("n旋转 180 度:")
print(result)
# 沿指定平面旋转(三维张量)
y = torch.arange(24).reshape(2, 3, 4)
print("n三维张量形状:", y.shape)
result = torch.rot90(y, 1, dims=[1, 2])
print("沿 [1, 2] 平面旋转 90 度后形状:", result.shape)
输出结果为:
原始张量:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
逆时针旋转 90 度:
tensor([[ 3, 7, 11],
[ 2, 6, 10],
[ 1, 5, 9],
[ 0, 4, 8]])
顺时针旋转 90 度:
tensor([[ 8, 4, 0],
[ 9, 5, 1],
[10, 6, 2],
[11, 7, 3]])
旋转 180 度:
tensor([[11, 10, 9, 8],
[ 7, 6, 5, 4],
[ 3, 2, 1, 0]])
三维张量形状: torch.Size([2, 3, 4])
沿 [1, 2] 平面旋转 90 度后形状: torch.Size([2, 4, 3])

Pytorch torch 参考手册