PyTorch torch.triangular_solve 函数
torch.triangular_solve 是 PyTorch 中用于求解三角线性方程组的函数。该函数求解 AX = B,其中 A 是三角矩阵。
函数定义
torch.triangular_solve(A, B, upper, transpose, unitriangular)
参数说明
A: 系数矩阵(必须是方阵)B: 右侧矩阵或向量upper: A 是否为上三角矩阵(默认 True)transpose: 是否转置 A(默认 False)unitriangular: 是否使用单位三角(默认 False)
使用示例
实例
import torch
# 创建上三角系数矩阵
A = torch.tensor([[3.0, 1.0, 2.0],
[0.0, 2.0, 1.0],
[0.0, 0.0, 1.0]])
# 右侧向量
B = torch.tensor([9.0, 5.0, 2.0])
# 求解 AX = B
X = torch.triangular_solve(B.unsqueeze(1), A, upper=True)
print("解 X:")
print(X.solution)
# 创建上三角系数矩阵
A = torch.tensor([[3.0, 1.0, 2.0],
[0.0, 2.0, 1.0],
[0.0, 0.0, 1.0]])
# 右侧向量
B = torch.tensor([9.0, 5.0, 2.0])
# 求解 AX = B
X = torch.triangular_solve(B.unsqueeze(1), A, upper=True)
print("解 X:")
print(X.solution)
输出结果为:
解 X:
tensor([[1.],
[2.],
[2.]])

Pytorch torch 参考手册