现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.reshape 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.reshape 是 PyTorch 中用于改变张量形状的函数。它返回一个新的张量,其元素数量与原张量相同,但形状不同。

这是深度学习中非常常用的操作,用于调整数据形状以适应不同层的输入要求。

函数定义

torch.reshape(input, shape)

参数:

  • input (Tensor): 输入张量。
  • shape (tuple 或 int): 目标形状。形状中的元素个数必须与输入张量的元素个数相同。可以使用 -1 让 PyTorch 自动推断维度。

返回值:

  • torch.Tensor: 返回一个形状改变后的张量视图。

使用示例

示例 1: 将一维张量 reshape 为二维

实例

import torch

# 创建一维张量
x = torch.arange(12)

print("原始形状:", x.shape)
print(x)

# 改变为 3x4 的二维张量
y = torch.reshape(x, (3, 4))

print("新形状:", y.shape)
print(y)

输出结果为:

原始形状: torch.Size([12])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
新形状: torch.Size([3, 4])
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

示例 2: 使用 -1 自动推断维度

实例

import torch

# 创建三维张量
x = torch.randn(2, 3, 4)

print("原始形状:", x.shape)

# 使用 -1 自动推断最后一维
y = torch.reshape(x, (2, -1))

print("新形状:", y.shape)

输出结果为:

原始形状: torch.Size([2, 3, 4])
新形状: torch.Size([2, 12])

示例 3: 展平为二维

实例

import torch

# 创建四维张量(典型批量图像数据)
x = torch.randn(32, 3, 224, 224)  # 批量大小 32,通道 3,高宽 224

print("原始形状:", x.shape)

# 展平为 (批量大小, 其他)
y = torch.reshape(x, (32, -1))

print("展平后形状:", y.shape)

输出结果为:

原始形状: torch.Size([32, 3, 224, 224])
展平后形状: torch.Size([32, 150528])

这在将图像数据送入全连接层之前经常使用。

示例 4: reshape 与 view 的区别

实例

import torch

# 创建连续张量
x = torch.arange(12).reshape(3, 4)

# reshape 可能返回视图或副本
y = torch.reshape(x, (4, 3))

print("y 是 x 的视图:", y.is_contiguous() or y.data_ptr() == x.data_ptr())

torch.reshape 可以处理非连续张量,而 tensor.view() 则要求张量必须是连续的。


注意事项

  • torch.reshape 返回的张量可能是原始数据的视图(view)也可能是副本(copy),取决于内存布局。
  • 如果需要保证返回视图,可以使用 tensor.view(),但要确保张量是连续的。
  • 形状的总元素数必须与原始张量的元素数相同。

Pytorch torch 参考手册 Pytorch torch 参考手册