PyTorch torch.broadcast_shapes 函数
torch.broadcast_shapes 是 PyTorch 中用于计算多个形状广播结果的函数。它返回输入形状广播后得到的形状,不会实际创建张量。
函数定义
torch.broadcast_shapes(*shapes)
使用示例
实例
import torch
# 计算形状广播结果
result = torch.broadcast_shapes((3, 1), (1, 4), (3, 4))
print("(3,1) + (1,4) + (3,4) ->", result)
# 输出: (3, 4)
# 多个形状
result = torch.broadcast_shapes((5,), (1, 5), (3, 1, 5))
print("(5,) + (1,5) + (3,1,5) ->", result)
# 输出: (3, 1, 5)
# 单个形状
result = torch.broadcast_shapes((2, 3))
print("(2,3) ->", result)
# 输出: (2, 3)
# 无法广播的情况会报错
try:
result = torch.broadcast_shapes((3,), (4,))
except RuntimeError as e:
print("广播错误:", e)
# 计算形状广播结果
result = torch.broadcast_shapes((3, 1), (1, 4), (3, 4))
print("(3,1) + (1,4) + (3,4) ->", result)
# 输出: (3, 4)
# 多个形状
result = torch.broadcast_shapes((5,), (1, 5), (3, 1, 5))
print("(5,) + (1,5) + (3,1,5) ->", result)
# 输出: (3, 1, 5)
# 单个形状
result = torch.broadcast_shapes((2, 3))
print("(2,3) ->", result)
# 输出: (2, 3)
# 无法广播的情况会报错
try:
result = torch.broadcast_shapes((3,), (4,))
except RuntimeError as e:
print("广播错误:", e)

Pytorch torch 参考手册