PyTorch torch.diagonal 函数
torch.diagonal 是 PyTorch 中用于提取张量对角线元素的函数。它返回输入张量指定对角线的视图。
函数定义
torch.diagonal(input, diagonal=0, dim1=0, dim2=1)
参数说明:
input: 输入张量diagonal: 对角线索引,0 表示主对角线dim1: 第一个维度dim2: 第二个维度
使用示例
实例
import torch
# 创建 3x3 矩阵
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 提取主对角线
y = torch.diagonal(x)
print(y)
# 创建 3x3 矩阵
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 提取主对角线
y = torch.diagonal(x)
print(y)
输出结果为:
tensor([1, 5, 9])
实例
import torch
# 创建 3x3 矩阵
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 提取主对角线上方的对角线
y = torch.diagonal(x, offset=1)
print(y)
# 创建 3x3 矩阵
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 提取主对角线上方的对角线
y = torch.diagonal(x, offset=1)
print(y)
输出结果为:
tensor([2, 6])
实例
import torch
# 创建 3D 张量
x = torch.arange(12).reshape(2, 3, 4)
# 提取指定维度的对角线
y = torch.diagonal(x, offset=0, dim1=1, dim2=2)
print(y.shape)
print(y)
# 创建 3D 张量
x = torch.arange(12).reshape(2, 3, 4)
# 提取指定维度的对角线
y = torch.diagonal(x, offset=0, dim1=1, dim2=2)
print(y.shape)
print(y)
输出结果为:
torch.Size([2, 3])
tensor([[ 0, 5, 10],
[12, 17, 22]])

Pytorch torch 参考手册