现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.tril_indices 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.tril_indices 是 PyTorch 中用于生成下三角矩阵索引的函数。返回下三角部分(包括对角线)的行索引和列索引。

函数定义

torch.tril_indices(row, column, offset=0, dtype=torch.long, device='cpu')

参数说明:

  • row: 行数
  • column: 列数
  • offset: 对角线偏移量
  • dtype: 返回的数据类型
  • device: 设备

使用示例

实例

import torch

# 生成 3x3 矩阵的下三角索引
row, col = torch.tril_indices(3, 3)
print("row:", row)
print("col:", col)

输出结果为:

row: tensor([0, 1, 1, 2, 2, 2])
col: tensor([0, 0, 1, 0, 1, 2])

实例

import torch

# 生成索引并用于索引操作
row, col = torch.tril_indices(3, 3, offset=1)

# 创建 3x3 矩阵
a = torch.ones(3, 3)

# 使用索引设置下三角部分
a[row, col] = 0
print(a)

输出结果为:

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

实例

import torch

# 非方阵情况
row, col = torch.tril_indices(3, 4)
print("row:", row)
print("col:", col)

输出结果为:

row: tensor([0, 1, 1, 2, 2, 2])
col: tensor([0, 0, 1, 0, 1, 2])

Pytorch torch 参考手册 Pytorch torch 参考手册