PyTorch torch.block_diag 函数
torch.block_diag 是 PyTorch 中用于创建块对角矩阵的函数。它将多个输入张量作为对角块沿对角线排列,形成一个更大的块对角矩阵。
函数定义
torch.block_diag(*tensors)
使用示例
实例
import torch
# 创建块对角矩阵
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6, 7]])
c = torch.tensor([[8], [9]])
result = torch.block_diag(a, b, c)
print("块对角矩阵:")
print(result)
# tensor([[1, 2, 0, 0, 0],
# [3, 4, 0, 0, 0],
# [0, 0, 5, 6, 7],
# [0, 0, 8, 0, 0],
# [0, 0, 9, 0, 0]])
# 多个矩阵块
m1 = torch.eye(2)
m2 = torch.eye(3)
result2 = torch.block_diag(m1, m2)
print("两个单位矩阵的块对角:")
print(result2)
# 创建块对角矩阵
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6, 7]])
c = torch.tensor([[8], [9]])
result = torch.block_diag(a, b, c)
print("块对角矩阵:")
print(result)
# tensor([[1, 2, 0, 0, 0],
# [3, 4, 0, 0, 0],
# [0, 0, 5, 6, 7],
# [0, 0, 8, 0, 0],
# [0, 0, 9, 0, 0]])
# 多个矩阵块
m1 = torch.eye(2)
m2 = torch.eye(3)
result2 = torch.block_diag(m1, m2)
print("两个单位矩阵的块对角:")
print(result2)

Pytorch torch 参考手册