PyTorch torch.diagflat 函数
torch.diagflat 是 PyTorch 中用于创建对角扁平矩阵的函数。无论输入是一维还是多维,都会被展平后作为对角线元素。
函数定义
torch.diagflat(input, diagonal=0)
参数说明:
input: 输入张量,会被展平diagonal: 对角线索引
使用示例
实例
import torch
# 创建一维张量
x = torch.tensor([1, 2, 3])
# 创建对角扁平矩阵
y = torch.diagflat(x)
print(y)
# 创建一维张量
x = torch.tensor([1, 2, 3])
# 创建对角扁平矩阵
y = torch.diagflat(x)
print(y)
输出结果为:
tensor([[1, 0, 0],
[0, 2, 0],
[0, 0, 3]])
实例
import torch
# 创建二维张量
x = torch.tensor([[1, 2], [3, 4]])
# 会被展平为 [1,2,3,4],创建 4x4 对角矩阵
y = torch.diagflat(x)
print(y)
# 创建二维张量
x = torch.tensor([[1, 2], [3, 4]])
# 会被展平为 [1,2,3,4],创建 4x4 对角矩阵
y = torch.diagflat(x)
print(y)
输出结果为:
tensor([[1, 0, 0, 0],
[0, 2, 0, 0],
[0, 0, 3, 0],
[0, 0, 0, 4]])

Pytorch torch 参考手册