PyTorch torch.tril_indices 函数
torch.tril_indices 是 PyTorch 中用于生成下三角矩阵索引的函数。返回下三角部分(包括对角线)的行索引和列索引。
函数定义
torch.tril_indices(row, column, offset=0, dtype=torch.long, device='cpu')
参数说明:
row: 行数column: 列数offset: 对角线偏移量dtype: 返回的数据类型device: 设备
使用示例
实例
import torch
# 生成 3x3 矩阵的下三角索引
row, col = torch.tril_indices(3, 3)
print("row:", row)
print("col:", col)
# 生成 3x3 矩阵的下三角索引
row, col = torch.tril_indices(3, 3)
print("row:", row)
print("col:", col)
输出结果为:
row: tensor([0, 1, 1, 2, 2, 2]) col: tensor([0, 0, 1, 0, 1, 2])
实例
import torch
# 生成索引并用于索引操作
row, col = torch.tril_indices(3, 3, offset=1)
# 创建 3x3 矩阵
a = torch.ones(3, 3)
# 使用索引设置下三角部分
a[row, col] = 0
print(a)
# 生成索引并用于索引操作
row, col = torch.tril_indices(3, 3, offset=1)
# 创建 3x3 矩阵
a = torch.ones(3, 3)
# 使用索引设置下三角部分
a[row, col] = 0
print(a)
输出结果为:
tensor([[1., 0., 0.],
[1., 1., 0.],
[1., 1., 1.]])
实例
import torch
# 非方阵情况
row, col = torch.tril_indices(3, 4)
print("row:", row)
print("col:", col)
# 非方阵情况
row, col = torch.tril_indices(3, 4)
print("row:", row)
print("col:", col)
输出结果为:
row: tensor([0, 1, 1, 2, 2, 2]) col: tensor([0, 0, 1, 0, 1, 2])

Pytorch torch 参考手册