PyTorch torch.triu_indices 函数
torch.triu_indices 是 PyTorch 中用于生成上三角矩阵索引的函数。返回上三角部分(包括对角线)的行索引和列索引。
函数定义
torch.triu_indices(row, column, offset=0, dtype=torch.long, device='cpu')
参数说明:
row: 行数column: 列数offset: 对角线偏移量dtype: 返回的数据类型device: 设备
使用示例
实例
import torch
# 生成 3x3 矩阵的上三角索引
row, col = torch.triu_indices(3, 3)
print("row:", row)
print("col:", col)
# 生成 3x3 矩阵的上三角索引
row, col = torch.triu_indices(3, 3)
print("row:", row)
print("col:", col)
输出结果为:
row: tensor([0, 0, 0, 1, 1, 2]) col: tensor([0, 1, 2, 1, 2, 2])
实例
import torch
# 生成索引并用于索引操作
row, col = torch.triu_indices(3, 3, offset=1)
# 创建 3x3 矩阵
a = torch.ones(3, 3)
# 使用索引设置上三角部分
a[row, col] = 0
print(a)
# 生成索引并用于索引操作
row, col = torch.triu_indices(3, 3, offset=1)
# 创建 3x3 矩阵
a = torch.ones(3, 3)
# 使用索引设置上三角部分
a[row, col] = 0
print(a)
输出结果为:
tensor([[1., 1., 1.],
[0., 1., 1.],
[0., 0., 1.]])
实例
import torch
# 非方阵情况
row, col = torch.triu_indices(3, 4)
print("row:", row)
print("col:", col)
# 非方阵情况
row, col = torch.triu_indices(3, 4)
print("row:", row)
print("col:", col)
输出结果为:
row: tensor([0, 0, 0, 1, 1, 2]) col: tensor([0, 1, 2, 1, 2, 2])

Pytorch torch 参考手册