现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.rot90 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.rot90 是 PyTorch 中用于将张量旋转 90 度的函数。

函数定义

torch.rot90(input, k=1, dims=[0, 1])

参数说明:

  • input - 输入张量
  • k - 旋转次数,正数表示逆时针,负数表示顺时针(默认 1)
  • dims - 旋转所在的平面维度(默认 [0, 1])

使用示例

实例

import torch

# 二维张量旋转
x = torch.arange(12).reshape(3, 4)
print("原始张量:")
print(x)

result = torch.rot90(x, 1)
print("逆时针旋转 90 度:")
print(result)

result = torch.rot90(x, -1)
print("n顺时针旋转 90 度:")
print(result)

result = torch.rot90(x, 2)
print("n旋转 180 度:")
print(result)

# 沿指定平面旋转(三维张量)
y = torch.arange(24).reshape(2, 3, 4)
print("n三维张量形状:", y.shape)

result = torch.rot90(y, 1, dims=[1, 2])
print("沿 [1, 2] 平面旋转 90 度后形状:", result.shape)

输出结果为:

原始张量:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
逆时针旋转 90 度:
tensor([[ 3,  7, 11],
        [ 2,  6, 10],
        [ 1,  5,  9],
        [ 0,  4,  8]])

顺时针旋转 90 度:
tensor([[ 8,  4,  0],
        [ 9,  5,  1],
        [10,  6,  2],
        [11,  7,  3]])

旋转 180 度:
tensor([[11, 10,  9,  8],
        [ 7,  6,  5,  4],
        [ 3,  2,  1,  0]])

三维张量形状: torch.Size([2, 3, 4])
沿 [1, 2] 平面旋转 90 度后形状: torch.Size([2, 4, 3])

Pytorch torch 参考手册 Pytorch torch 参考手册