现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.triangular_solve 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.triangular_solve 是 PyTorch 中用于求解三角线性方程组的函数。该函数求解 AX = B,其中 A 是三角矩阵。

函数定义

torch.triangular_solve(A, B, upper, transpose, unitriangular)

参数说明

  • A: 系数矩阵(必须是方阵)
  • B: 右侧矩阵或向量
  • upper: A 是否为上三角矩阵(默认 True)
  • transpose: 是否转置 A(默认 False)
  • unitriangular: 是否使用单位三角(默认 False)

使用示例

实例

import torch

# 创建上三角系数矩阵
A = torch.tensor([[3.0, 1.0, 2.0],
                  [0.0, 2.0, 1.0],
                  [0.0, 0.0, 1.0]])

# 右侧向量
B = torch.tensor([9.0, 5.0, 2.0])

# 求解 AX = B
X = torch.triangular_solve(B.unsqueeze(1), A, upper=True)

print("解 X:")
print(X.solution)

输出结果为:

解 X:
tensor([[1.],
        [2.],
        [2.]])

Pytorch torch 参考手册 Pytorch torch 参考手册