PyTorch torch.linalg.solve 函数
torch.linalg.solve 是 PyTorch 线性代数模块中用于求解线性方程组的函数。它求解 AX = B。
函数定义
torch.linalg.solve(A, B, left=True, out=None)
参数:
A(Tensor): 系数矩阵。B(Tensor): 右侧矩阵或向量。left(bool, 可选): 如果为 True,求解 AX = B;否则求 XA = B。默认为 True。out(Tensor, 可选): 输出张量。
返回值:
torch.Tensor: 返回线性方程组的解。
使用示例
实例
import torch
# 创建系数矩阵和右侧向量
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([5.0, 11.0])
# 求解 AX = B
X = torch.linalg.solve(A, B)
print("系数矩阵 A:")
print(A)
print("n右侧向量 B:")
print(B)
print("n解 X:")
print(X)
print("n验证: A @ X =")
print(A @ X)
# 创建系数矩阵和右侧向量
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([5.0, 11.0])
# 求解 AX = B
X = torch.linalg.solve(A, B)
print("系数矩阵 A:")
print(A)
print("n右侧向量 B:")
print(B)
print("n解 X:")
print(X)
print("n验证: A @ X =")
print(A @ X)
输出结果为:
系数矩阵 A:
tensor([[1., 2.],
[3., 4.]])
右侧向量 B:
tensor([ 5., 11.])
解 X:
tensor([1., 2.])
验证: A @ X =
tensor([ 5., 11.])

Pytorch torch 参考手册