PyTorch torch.tril 函数
torch.tril 是 PyTorch 中用于提取矩阵下三角部分(包括主对角线)的函数。上三角部分会被设为0。
函数定义
torch.tril(input, diagonal=0, out=None)
参数说明:
input: 输入张量diagonal: 对角线索引,0 表示主对角线out: 输出张量
使用示例
实例
import torch
# 创建矩阵
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 提取下三角部分
y = torch.tril(a)
print(y)
# 创建矩阵
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 提取下三角部分
y = torch.tril(a)
print(y)
输出结果为:
tensor([[1, 0, 0],
[4, 5, 0],
[7, 8, 9]])
实例
import torch
# 创建矩阵
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 提取主对角线以下第一对角线的下三角
y = torch.tril(a, diagonal=1)
print(y)
# 创建矩阵
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 提取主对角线以下第一对角线的下三角
y = torch.tril(a, diagonal=1)
print(y)
输出结果为:
tensor([[1, 2, 0],
[4, 5, 6],
[7, 8, 9]])
实例
import torch
# 创建非方阵
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 提取下三角部分
y = torch.tril(a)
print(y)
# 创建非方阵
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 提取下三角部分
y = torch.tril(a)
print(y)
输出结果为:
tensor([[1, 0, 0],
[4, 5, 0]])

Pytorch torch 参考手册