现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.block_diag 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.block_diag 是 PyTorch 中用于创建块对角矩阵的函数。它将多个输入张量作为对角块沿对角线排列,形成一个更大的块对角矩阵。

函数定义

torch.block_diag(*tensors)

使用示例

实例

import torch

# 创建块对角矩阵
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6, 7]])
c = torch.tensor([[8], [9]])

result = torch.block_diag(a, b, c)
print("块对角矩阵:")
print(result)
# tensor([[1, 2, 0, 0, 0],
#         [3, 4, 0, 0, 0],
#         [0, 0, 5, 6, 7],
#         [0, 0, 8, 0, 0],
#         [0, 0, 9, 0, 0]])

# 多个矩阵块
m1 = torch.eye(2)
m2 = torch.eye(3)
result2 = torch.block_diag(m1, m2)
print("两个单位矩阵的块对角:")
print(result2)

Pytorch torch 参考手册 Pytorch torch 参考手册