PyTorch torch.expand 函数
torch.expand 是 PyTorch 中用于扩展张量尺寸的函数。它通过复制视图的方式扩展张量,而不会实际复制数据。
扩展后的张量在内存中是共享原始数据的视图,因此这是一种内存高效的操作。
函数定义
torch.expand(*sizes) torch.expand(input, *sizes)
参数:
input(Tensor): 输入张量。*sizes(torch.Size 或 int): 目标尺寸。尺寸中可以有单个维度为 -1,表示保持该维度不变。值为1的维度可以扩展为更大的尺寸。
返回值:
torch.Tensor: 返回扩展后的张量视图。
使用示例
实例
import torch
# 创建一个维度为1的张量
x = torch.tensor([[1], [2], [3]])
print("原始张量:")
print(x)
print("形状:", x.shape)
# 扩展到更大的尺寸
y = x.expand(3, 4)
print("n扩展后:")
print(y)
print("形状:", y.shape)
# 创建一个维度为1的张量
x = torch.tensor([[1], [2], [3]])
print("原始张量:")
print(x)
print("形状:", x.shape)
# 扩展到更大的尺寸
y = x.expand(3, 4)
print("n扩展后:")
print(y)
print("形状:", y.shape)
输出结果为:
原始张量:
tensor([[1],
[2],
[3]])
形状: torch.Size([3, 1])
扩展后:
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
形状: torch.Size([3, 4])
实例
import torch
# 使用 -1 保持维度不变
x = torch.tensor([1, 2, 3, 4])
# 扩展为 2x4 的张量
y = x.expand(2, -1)
print("原始:", x.shape)
print("扩展后:", y.shape)
print(y)
# 使用 -1 保持维度不变
x = torch.tensor([1, 2, 3, 4])
# 扩展为 2x4 的张量
y = x.expand(2, -1)
print("原始:", x.shape)
print("扩展后:", y.shape)
print(y)
输出结果为:
原始: torch.Size([4])
扩展后: torch.Size([2, 4])
tensor([[1, 2, 3, 4],
[1, 2, 3, 4]])
实例
import torch
# 广播机制的应用场景
# 将列向量扩展为矩阵
column = torch.randn(5, 1)
# 扩展为与矩阵相加
matrix = torch.randn(5, 10)
# 广播:列向量会自动扩展
result = matrix + column
print("列向量形状:", column.shape)
print("矩阵形状:", matrix.shape)
print("结果形状:", result.shape)
# 广播机制的应用场景
# 将列向量扩展为矩阵
column = torch.randn(5, 1)
# 扩展为与矩阵相加
matrix = torch.randn(5, 10)
# 广播:列向量会自动扩展
result = matrix + column
print("列向量形状:", column.shape)
print("矩阵形状:", matrix.shape)
print("结果形状:", result.shape)
输出结果为:
列向量形状: torch.Size([5, 1]) 矩阵形状: torch.Size([5, 10]) 结果形状: torch.Size([5, 10])
实例
import torch
# 从2D张量扩展到3D
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 扩展第三个维度
y = x.expand(2, 2, 3)
print("原始形状:", x.shape)
print("扩展后形状:", y.shape)
print("n扩展后的张量:")
print(y)
# 验证数据是共享的
print("n修改原始张量:")
x[0, 0] = 100
print("x[0,0]:", x[0, 0])
print("y[0,0,0]:", y[0, 0, 0])
# 从2D张量扩展到3D
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 扩展第三个维度
y = x.expand(2, 2, 3)
print("原始形状:", x.shape)
print("扩展后形状:", y.shape)
print("n扩展后的张量:")
print(y)
# 验证数据是共享的
print("n修改原始张量:")
x[0, 0] = 100
print("x[0,0]:", x[0, 0])
print("y[0,0,0]:", y[0, 0, 0])
输出结果为:
原始形状: torch.Size([2, 3])
扩展后形状: torch.Size([2, 2, 3])
扩展后的张量:
tensor([[[1, 2, 3],
[4, 5, 6]],
[[1, 2, 3],
[4, 5, 6]]])
修改原始张量:
x[0,0]: tensor(100)
y[0,0,0]: tensor(100)
注意:torch.expand 只可以扩展维度大小为1的维度到更大的尺寸。不能将大于1的维度缩小或扩展到不同的尺寸。返回的是视图,不是副本。

Pytorch torch 参考手册