PyTorch torch.transpose 函数
torch.transpose 是 PyTorch 中用于交换张量两个维度的函数。它返回输入张量的转置视图。
这是深度学习中常用的操作,用于改变数据形状以适应不同的计算需求。
函数定义
torch.transpose(input, dim0, dim1)
参数:
input(Tensor): 输入张量。dim0(int): 第一个要交换的维度。dim1(int): 第二个要交换的维度。
返回值:
torch.Tensor: 返回转置后的张量视图。
使用示例
示例 1: 二维矩阵转置
实例
import torch
# 创建 3x4 矩阵
x = torch.randn(3, 4)
# 转置
y = torch.transpose(x, 0, 1)
print("原始形状:", x.shape)
print("转置后形状:", y.shape)
print("原始:")
print(x)
print("转置后:")
print(y)
# 创建 3x4 矩阵
x = torch.randn(3, 4)
# 转置
y = torch.transpose(x, 0, 1)
print("原始形状:", x.shape)
print("转置后形状:", y.shape)
print("原始:")
print(x)
print("转置后:")
print(y)
输出结果为:
原始形状: torch.Size([3, 4])
转置后形状: torch.Size([4, 3])
原始:
tensor([[ 0.3364, -0.7844, 0.9760, 0.4381],
[ 0.7865, -1.2775, 0.5767, -0.5268],
[-0.6399, -0.6743, -0.2972, -0.4781]])
转置后:
tensor([[ 0.3364, 0.7865, -0.6399],
[-0.7844, -1.2775, -0.6743],
[ 0.9760, 0.5767, -0.2972],
[ 0.4381, -0.5268, -0.4781]])
示例 2: 多维张量转置
实例
import torch
# 创建 3D 张量
x = torch.randn(2, 3, 4)
# 交换 dim=1 和 dim=2
y = torch.transpose(x, 1, 2)
print("原始形状:", x.shape)
print("转置后形状:", y.shape)
# 创建 3D 张量
x = torch.randn(2, 3, 4)
# 交换 dim=1 和 dim=2
y = torch.transpose(x, 1, 2)
print("原始形状:", x.shape)
print("转置后形状:", y.shape)
输出结果为:
原始形状: torch.Size([2, 3, 4]) 转置后形状: torch.Size([2, 4, 3])
注意事项
torch.transpose返回的是视图,不是副本。- 对于二维张量,也可以使用
tensor.t()方法。

Pytorch torch 参考手册