PyTorch torch.reshape 函数
torch.reshape 是 PyTorch 中用于改变张量形状的函数。它返回一个新的张量,其元素数量与原张量相同,但形状不同。
这是深度学习中非常常用的操作,用于调整数据形状以适应不同层的输入要求。
函数定义
torch.reshape(input, shape)
参数:
input(Tensor): 输入张量。shape(tuple 或 int): 目标形状。形状中的元素个数必须与输入张量的元素个数相同。可以使用-1让 PyTorch 自动推断维度。
返回值:
torch.Tensor: 返回一个形状改变后的张量视图。
使用示例
示例 1: 将一维张量 reshape 为二维
实例
import torch
# 创建一维张量
x = torch.arange(12)
print("原始形状:", x.shape)
print(x)
# 改变为 3x4 的二维张量
y = torch.reshape(x, (3, 4))
print("新形状:", y.shape)
print(y)
# 创建一维张量
x = torch.arange(12)
print("原始形状:", x.shape)
print(x)
# 改变为 3x4 的二维张量
y = torch.reshape(x, (3, 4))
print("新形状:", y.shape)
print(y)
输出结果为:
原始形状: torch.Size([12])
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
新形状: torch.Size([3, 4])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
示例 2: 使用 -1 自动推断维度
实例
import torch
# 创建三维张量
x = torch.randn(2, 3, 4)
print("原始形状:", x.shape)
# 使用 -1 自动推断最后一维
y = torch.reshape(x, (2, -1))
print("新形状:", y.shape)
# 创建三维张量
x = torch.randn(2, 3, 4)
print("原始形状:", x.shape)
# 使用 -1 自动推断最后一维
y = torch.reshape(x, (2, -1))
print("新形状:", y.shape)
输出结果为:
原始形状: torch.Size([2, 3, 4]) 新形状: torch.Size([2, 12])
示例 3: 展平为二维
实例
import torch
# 创建四维张量(典型批量图像数据)
x = torch.randn(32, 3, 224, 224) # 批量大小 32,通道 3,高宽 224
print("原始形状:", x.shape)
# 展平为 (批量大小, 其他)
y = torch.reshape(x, (32, -1))
print("展平后形状:", y.shape)
# 创建四维张量(典型批量图像数据)
x = torch.randn(32, 3, 224, 224) # 批量大小 32,通道 3,高宽 224
print("原始形状:", x.shape)
# 展平为 (批量大小, 其他)
y = torch.reshape(x, (32, -1))
print("展平后形状:", y.shape)
输出结果为:
原始形状: torch.Size([32, 3, 224, 224]) 展平后形状: torch.Size([32, 150528])
这在将图像数据送入全连接层之前经常使用。
示例 4: reshape 与 view 的区别
实例
import torch
# 创建连续张量
x = torch.arange(12).reshape(3, 4)
# reshape 可能返回视图或副本
y = torch.reshape(x, (4, 3))
print("y 是 x 的视图:", y.is_contiguous() or y.data_ptr() == x.data_ptr())
# 创建连续张量
x = torch.arange(12).reshape(3, 4)
# reshape 可能返回视图或副本
y = torch.reshape(x, (4, 3))
print("y 是 x 的视图:", y.is_contiguous() or y.data_ptr() == x.data_ptr())
torch.reshape 可以处理非连续张量,而 tensor.view() 则要求张量必须是连续的。
注意事项
torch.reshape返回的张量可能是原始数据的视图(view)也可能是副本(copy),取决于内存布局。- 如果需要保证返回视图,可以使用
tensor.view(),但要确保张量是连续的。 - 形状的总元素数必须与原始张量的元素数相同。

Pytorch torch 参考手册