PyTorch torch.diag 函数
torch.diag 是 PyTorch 中用于创建对角矩阵或提取张量对角元素的函数。
函数定义
torch.diag(input, diagonal=0, out=None)
参数说明:
input: 输入张量diagonal: 对角线索引,0 表示主对角线,正值表示上方对角线,负值表示下方对角线
使用示例
实例
import torch
# 创建一维张量
x = torch.tensor([1, 2, 3])
# 创建对角矩阵
y = torch.diag(x)
print(y)
# 创建一维张量
x = torch.tensor([1, 2, 3])
# 创建对角矩阵
y = torch.diag(x)
print(y)
输出结果为:
tensor([[1, 0, 0],
[0, 2, 0],
[0, 0, 3]])
实例
import torch
# 从矩阵中提取对角元素
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 提取主对角线
y = torch.diag(x)
print(y)
# 从矩阵中提取对角元素
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 提取主对角线
y = torch.diag(x)
print(y)
输出结果为:
tensor([1, 5, 9])

Pytorch torch 参考手册