现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.broadcast_shapes 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.broadcast_shapes 是 PyTorch 中用于计算多个形状广播结果的函数。它返回输入形状广播后得到的形状,不会实际创建张量。

函数定义

torch.broadcast_shapes(*shapes)

使用示例

实例

import torch

# 计算形状广播结果
result = torch.broadcast_shapes((3, 1), (1, 4), (3, 4))
print("(3,1) + (1,4) + (3,4) ->", result)
# 输出: (3, 4)

# 多个形状
result = torch.broadcast_shapes((5,), (1, 5), (3, 1, 5))
print("(5,) + (1,5) + (3,1,5) ->", result)
# 输出: (3, 1, 5)

# 单个形状
result = torch.broadcast_shapes((2, 3))
print("(2,3) ->", result)
# 输出: (2, 3)

# 无法广播的情况会报错
try:
    result = torch.broadcast_shapes((3,), (4,))
except RuntimeError as e:
    print("广播错误:", e)

Pytorch torch 参考手册 Pytorch torch 参考手册