PyTorch torch.lu_solve 函数
torch.lu_solve 是 PyTorch 中使用 LU 分解求解线性方程组的函数。它利用已经计算好的 LU 分解来高效地求解 AX = B。
函数定义
torch.lu_solve(B, LU, pivots, out=None)
参数:
B(Tensor): 右侧矩阵或向量。LU(Tensor): LU 分解得到的矩阵。pivots(Tensor): LU 分解的主元索引。out(Tensor, 可选): 输出张量。
返回值:
torch.Tensor: 返回线性方程组的解。
使用示例
实例
import torch
# 创建矩阵和右侧向量
A = torch.tensor([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]], dtype=torch.float64)
B = torch.tensor([14.0, 32.0, 50.0], dtype=torch.float64)
# LU 分解
LU, pivots = torch.lu(A)
# 使用 LU 分解求解
X = torch.lu_solve(B, LU, pivots)
print("矩阵 A:")
print(A)
print("n右侧向量 B:")
print(B)
print("n解 X:")
print(X)
print("n验证: A @ X =")
print(A @ X)
# 创建矩阵和右侧向量
A = torch.tensor([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]], dtype=torch.float64)
B = torch.tensor([14.0, 32.0, 50.0], dtype=torch.float64)
# LU 分解
LU, pivots = torch.lu(A)
# 使用 LU 分解求解
X = torch.lu_solve(B, LU, pivots)
print("矩阵 A:")
print(A)
print("n右侧向量 B:")
print(B)
print("n解 X:")
print(X)
print("n验证: A @ X =")
print(A @ X)
输出结果为:
矩阵 A:
tensor([[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]], dtype=torch.float64)
右侧向量 B:
tensor([14., 32., 50.], dtype=torch.float64)
解 X:
tensor([1., 2., 3.], dtype=torch.float64)
验证: A @ X =
tensor([14., 32., 50.], dtype=torch.float64)

Pytorch torch 参考手册