现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.transpose 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.transpose 是 PyTorch 中用于交换张量两个维度的函数。它返回输入张量的转置视图。

这是深度学习中常用的操作,用于改变数据形状以适应不同的计算需求。

函数定义

torch.transpose(input, dim0, dim1)

参数:

  • input (Tensor): 输入张量。
  • dim0 (int): 第一个要交换的维度。
  • dim1 (int): 第二个要交换的维度。

返回值:

  • torch.Tensor: 返回转置后的张量视图。

使用示例

示例 1: 二维矩阵转置

实例

import torch

# 创建 3x4 矩阵
x = torch.randn(3, 4)

# 转置
y = torch.transpose(x, 0, 1)

print("原始形状:", x.shape)
print("转置后形状:", y.shape)
print("原始:")
print(x)
print("转置后:")
print(y)

输出结果为:

原始形状: torch.Size([3, 4])
转置后形状: torch.Size([4, 3])
原始:
tensor([[ 0.3364, -0.7844,  0.9760,  0.4381],
        [ 0.7865, -1.2775,  0.5767, -0.5268],
        [-0.6399, -0.6743, -0.2972, -0.4781]])
转置后:
tensor([[ 0.3364,  0.7865, -0.6399],
        [-0.7844, -1.2775, -0.6743],
        [ 0.9760,  0.5767, -0.2972],
        [ 0.4381, -0.5268, -0.4781]])

示例 2: 多维张量转置

实例

import torch

# 创建 3D 张量
x = torch.randn(2, 3, 4)

# 交换 dim=1 和 dim=2
y = torch.transpose(x, 1, 2)

print("原始形状:", x.shape)
print("转置后形状:", y.shape)

输出结果为:

原始形状: torch.Size([2, 3, 4])
转置后形状: torch.Size([2, 4, 3])

注意事项

  • torch.transpose 返回的是视图,不是副本。
  • 对于二维张量,也可以使用 tensor.t() 方法。

Pytorch torch 参考手册 Pytorch torch 参考手册