PyTorch torch.broadcast_tensors 函数
torch.broadcast_tensors 是 PyTorch 中用于对多个张量进行广播的函数。它将输入张量广播到相同的形状,返回一组可以逐元素操作的张量。广播规则遵循 NumPy 的广播机制。
函数定义
torch.broadcast_tensors(*tensors)
使用示例
实例
import torch
# 不同形状的张量广播
x = torch.tensor([[1, 2, 3]]) # 形状 (1, 3)
y = torch.tensor([[1], [2], [3]]) # 形状 (3, 1)
a, b = torch.broadcast_tensors(x, y)
print("x 广播后形状:", a.shape)
print("y 广播后形状:", b.shape)
print("广播后的 x:")
print(a)
print("广播后的 y:")
print(b)
print("逐元素相加:")
print(a + b)
# 多个张量
x1 = torch.randn(3, 1, 5)
x2 = torch.randn(1, 4, 5)
x3 = torch.randn(3, 4, 1)
r1, r2, r3 = torch.broadcast_tensors(x1, x2, x3)
print("广播后形状:", r1.shape)
# 不同形状的张量广播
x = torch.tensor([[1, 2, 3]]) # 形状 (1, 3)
y = torch.tensor([[1], [2], [3]]) # 形状 (3, 1)
a, b = torch.broadcast_tensors(x, y)
print("x 广播后形状:", a.shape)
print("y 广播后形状:", b.shape)
print("广播后的 x:")
print(a)
print("广播后的 y:")
print(b)
print("逐元素相加:")
print(a + b)
# 多个张量
x1 = torch.randn(3, 1, 5)
x2 = torch.randn(1, 4, 5)
x3 = torch.randn(3, 4, 1)
r1, r2, r3 = torch.broadcast_tensors(x1, x2, x3)
print("广播后形状:", r1.shape)

Pytorch torch 参考手册