现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.expand 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.expand 是 PyTorch 中用于扩展张量尺寸的函数。它通过复制视图的方式扩展张量,而不会实际复制数据。

扩展后的张量在内存中是共享原始数据的视图,因此这是一种内存高效的操作。

函数定义

torch.expand(*sizes)
torch.expand(input, *sizes)

参数:

  • input (Tensor): 输入张量。
  • *sizes (torch.Size 或 int): 目标尺寸。尺寸中可以有单个维度为 -1,表示保持该维度不变。值为1的维度可以扩展为更大的尺寸。

返回值:

  • torch.Tensor: 返回扩展后的张量视图。

使用示例

实例

import torch

# 创建一个维度为1的张量
x = torch.tensor([[1], [2], [3]])

print("原始张量:")
print(x)
print("形状:", x.shape)

# 扩展到更大的尺寸
y = x.expand(3, 4)

print("n扩展后:")
print(y)
print("形状:", y.shape)

输出结果为:

原始张量:
tensor([[1],
        [2],
        [3]])
形状: torch.Size([3, 1])

扩展后:
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
形状: torch.Size([3, 4])

实例

import torch

# 使用 -1 保持维度不变
x = torch.tensor([1, 2, 3, 4])

# 扩展为 2x4 的张量
y = x.expand(2, -1)

print("原始:", x.shape)
print("扩展后:", y.shape)
print(y)

输出结果为:

原始: torch.Size([4])
扩展后: torch.Size([2, 4])
tensor([[1, 2, 3, 4],
        [1, 2, 3, 4]])

实例

import torch

# 广播机制的应用场景
# 将列向量扩展为矩阵
column = torch.randn(5, 1)

# 扩展为与矩阵相加
matrix = torch.randn(5, 10)

# 广播:列向量会自动扩展
result = matrix + column

print("列向量形状:", column.shape)
print("矩阵形状:", matrix.shape)
print("结果形状:", result.shape)

输出结果为:

列向量形状: torch.Size([5, 1])
矩阵形状: torch.Size([5, 10])
结果形状: torch.Size([5, 10])

实例

import torch

# 从2D张量扩展到3D
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# 扩展第三个维度
y = x.expand(2, 2, 3)

print("原始形状:", x.shape)
print("扩展后形状:", y.shape)
print("n扩展后的张量:")
print(y)

# 验证数据是共享的
print("n修改原始张量:")
x[0, 0] = 100
print("x[0,0]:", x[0, 0])
print("y[0,0,0]:", y[0, 0, 0])

输出结果为:

原始形状: torch.Size([2, 3])
扩展后形状: torch.Size([2, 2, 3])

扩展后的张量:
tensor([[[1, 2, 3],
         [4, 5, 6]],

        [[1, 2, 3],
         [4, 5, 6]]])

修改原始张量:
x[0,0]: tensor(100)
y[0,0,0]: tensor(100)


注意:torch.expand 只可以扩展维度大小为1的维度到更大的尺寸。不能将大于1的维度缩小或扩展到不同的尺寸。返回的是视图,不是副本。


Pytorch torch 参考手册 Pytorch torch 参考手册