PyTorch torch.cholesky_solve 函数
torch.cholesky_solve 是 PyTorch 中使用 Cholesky 分解求解线性方程组的函数。它利用 Cholesky 分解来高效地求解 AX = B。
函数定义
torch.cholesky_solve(B, L, upper=False, out=None)
参数:
B(Tensor): 线性方程组的右侧矩阵或向量。L(Tensor): Cholesky 分解得到的上三角或下三角矩阵。upper(bool, 可选): 如果为 True,L 是上三角矩阵;否则是下三角矩阵。默认为 False。out(Tensor, 可选): 输出张量。
返回值:
torch.Tensor: 返回线性方程组的解 X。
使用示例
实例
import torch
# 创建一个对称正定矩阵 A
A = torch.tensor([[4.0, 2.0, 2.0],
[2.0, 5.0, 3.0],
[2.0, 3.0, 6.0]], dtype=torch.float64)
# 创建右侧向量 B
B = torch.tensor([3, 4, 5], dtype=torch.float64)
# Cholesky 分解
L = torch.cholesky(A)
# 使用 Cholesky 求解 AX = B
X = torch.cholesky_solve(B, L)
print("矩阵 A:")
print(A)
print("n右侧向量 B:")
print(B)
print("n解 X:")
print(X)
print("n验证: A @ X =")
print(A @ X)
# 创建一个对称正定矩阵 A
A = torch.tensor([[4.0, 2.0, 2.0],
[2.0, 5.0, 3.0],
[2.0, 3.0, 6.0]], dtype=torch.float64)
# 创建右侧向量 B
B = torch.tensor([3, 4, 5], dtype=torch.float64)
# Cholesky 分解
L = torch.cholesky(A)
# 使用 Cholesky 求解 AX = B
X = torch.cholesky_solve(B, L)
print("矩阵 A:")
print(A)
print("n右侧向量 B:")
print(B)
print("n解 X:")
print(X)
print("n验证: A @ X =")
print(A @ X)
输出结果为:
矩阵 A:
tensor([[4., 2., 2.],
[2., 5., 3.],
[2., 3., 6.]], dtype=torch.float64)
右侧向量 B:
tensor([3., 4., 5.], dtype=torch.float64)
解 X:
tensor([-2.5000, 5.0000, -2.0000], dtype=torch.float64)
验证: A @ X =
tensor([3., 4., 5.], dtype=torch.float64)

Pytorch torch 参考手册