现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.solve 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.solve 是 PyTorch 中用于求解线性方程组的函数。它求解 AX = B,其中 A 是系数矩阵,B 是右侧矩阵或向量。

函数定义

torch.solve(B, A, out=None)

参数:

  • B (Tensor): 右侧矩阵或向量。
  • A (Tensor): 系数矩阵。
  • out (tuple, 可选): 输出元组。

返回值:

  • tuple: 返回 (X, LU) 的元组,其中 X 是解,LU 是分解。

使用示例

实例

import torch

# 创建系数矩阵和右侧向量
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([5.0, 11.0])

# 求解 AX = B
X, _ = torch.solve(B, A)

print("系数矩阵 A:")
print(A)
print("n右侧向量 B:")
print(B)
print("n解 X:")
print(X)
print("n验证: A @ X =")
print(A @ X)

输出结果为:

系数矩阵 A:
tensor([[1., 2.],
        [3., 4.]])
右侧向量 B:
tensor([ 5., 11.])
解 X:
tensor([1., 2.])
验证: A @ X =
tensor([ 5., 11.])

Pytorch torch 参考手册 Pytorch torch 参考手册