现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.diagonal 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.diagonal 是 PyTorch 中用于提取张量对角线元素的函数。它返回输入张量指定对角线的视图。

函数定义

torch.diagonal(input, diagonal=0, dim1=0, dim2=1)

参数说明:

  • input: 输入张量
  • diagonal: 对角线索引,0 表示主对角线
  • dim1: 第一个维度
  • dim2: 第二个维度

使用示例

实例

import torch

# 创建 3x3 矩阵
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 提取主对角线
y = torch.diagonal(x)
print(y)

输出结果为:

tensor([1, 5, 9])

实例

import torch

# 创建 3x3 矩阵
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 提取主对角线上方的对角线
y = torch.diagonal(x, offset=1)
print(y)

输出结果为:

tensor([2, 6])

实例

import torch

# 创建 3D 张量
x = torch.arange(12).reshape(2, 3, 4)

# 提取指定维度的对角线
y = torch.diagonal(x, offset=0, dim1=1, dim2=2)
print(y.shape)
print(y)

输出结果为:

torch.Size([2, 3])
tensor([[ 0,  5, 10],
        [12, 17, 22]])

Pytorch torch 参考手册 Pytorch torch 参考手册