PyTorch torch.expand_as 函数
torch.expand_as 是 PyTorch 中用于将张量扩展到与另一个张量相同尺寸的函数。它是 torch.expand 的便捷版本,自动使用参考张量的尺寸进行扩展。
函数定义
torch.expand_as(other)
参数:
other(Tensor): 参考张量,当前张量将被扩展到与该张量相同的尺寸。
返回值:
torch.Tensor: 返回扩展后的张量视图。
使用示例
实例
import torch
# 创建一个向量
x = torch.tensor([1, 2, 3, 4])
# 创建一个目标矩阵
other = torch.randn(3, 4)
# 将 x 扩展为与 other 相同的形状
y = x.expand_as(other)
print("原始向量形状:", x.shape)
print("参考张量形状:", other.shape)
print("扩展后形状:", y.shape)
print("n扩展后的张量:")
print(y)
# 创建一个向量
x = torch.tensor([1, 2, 3, 4])
# 创建一个目标矩阵
other = torch.randn(3, 4)
# 将 x 扩展为与 other 相同的形状
y = x.expand_as(other)
print("原始向量形状:", x.shape)
print("参考张量形状:", other.shape)
print("扩展后形状:", y.shape)
print("n扩展后的张量:")
print(y)
输出结果为:
原始向量形状: torch.Size([4])
参考张量形状: torch.Size([3, 4])
扩展后形状: torch.Size([3, 4])
扩展后的张量:
tensor([[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]])
实例
import torch
# 创建一个列向量
x = torch.tensor([[1], [2], [3]])
# 创建一个目标3D张量
other = torch.randn(3, 4, 5)
# 将 x 扩展为与 other 相同的形状
y = x.expand_as(other)
print("原始列向量形状:", x.shape)
print("参考张量形状:", other.shape)
print("扩展后形状:", y.shape)
# 创建一个列向量
x = torch.tensor([[1], [2], [3]])
# 创建一个目标3D张量
other = torch.randn(3, 4, 5)
# 将 x 扩展为与 other 相同的形状
y = x.expand_as(other)
print("原始列向量形状:", x.shape)
print("参考张量形状:", other.shape)
print("扩展后形状:", y.shape)
输出结果为:
原始列向量形状: torch.Size([3, 1]) 参考张量形状: torch.Size([3, 4, 5]) 扩展后形状: torch.Size([3, 4, 5])
实例
import torch
# 在神经网络中的应用示例
# 假设有偏置向量需要广播到特征图
bias = torch.tensor([0.1, 0.2, 0.3]) # 3个通道的偏置
# 模拟卷积输出的特征图 (batch, channel, height, width)
feature_map = torch.randn(8, 3, 32, 32)
# 将偏置扩展到与特征图相同的形状
bias_expanded = bias.expand_as(feature_map)
print("偏置形状:", bias.shape)
print("特征图形状:", feature_map.shape)
print("扩展后偏置形状:", bias_expanded.shape)
# 添加偏置
output = feature_map + bias_expanded
print("n加上偏置后的输出形状:", output.shape)
# 在神经网络中的应用示例
# 假设有偏置向量需要广播到特征图
bias = torch.tensor([0.1, 0.2, 0.3]) # 3个通道的偏置
# 模拟卷积输出的特征图 (batch, channel, height, width)
feature_map = torch.randn(8, 3, 32, 32)
# 将偏置扩展到与特征图相同的形状
bias_expanded = bias.expand_as(feature_map)
print("偏置形状:", bias.shape)
print("特征图形状:", feature_map.shape)
print("扩展后偏置形状:", bias_expanded.shape)
# 添加偏置
output = feature_map + bias_expanded
print("n加上偏置后的输出形状:", output.shape)
输出结果为:
偏置形状: torch.Size([3]) 特征图形状: torch.Size([8, 3, 32, 32]) 扩展后偏置形状: torch.Size([8, 3, 32, 32]) 加上偏置后的输出形状: torch.Size([8, 3, 32, 32])
注意:torch.expand_as 只可以扩展维度大小为1的维度到更大的尺寸。other 张量的尺寸必须与扩展后的尺寸兼容(即当前张量的维度可以通过广播扩展到目标尺寸)。返回的是视图,不是副本。

Pytorch torch 参考手册