现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.lu 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.lu 是 PyTorch 中用于计算矩阵 LU 分解的函数。LU 分解将矩阵 A 分解为 A = P * L * U,其中 L 是下三角矩阵,U 是上三角矩阵,P 是置换矩阵。

函数定义

torch.lu(A, pivot=True, get_infos=False, out=None)

参数:

  • A (Tensor): 输入矩阵。
  • pivot (bool, 可选): 是否进行 LU 列主元分解。默认为 True。
  • get_infos (bool, 可选): 如果为 True,返回信息。默认为 False。
  • out (tuple, 可选): 输出元组。

返回值:

  • tuple: 返回 (pivot, L, U) 的元组。

使用示例

实例

import torch

# 创建矩阵
A = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0]])

# LU 分解
LU, pivots = torch.lu(A, pivot=True)

print("矩阵 A:")
print(A)
print("nLU 矩阵:")
print(LU)
print("n主元索引:")
print(pivots)

输出结果为:

矩阵 A:
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
LU 矩阵:
tensor([[7., 8., 9.],
        [0.2, 0.4, 0.6],
        [0.6, 0.8, 0.0]])
主元索引:
tensor([3, 3, 3], dtype=torch.int32)

实例 - 使用 get_infos

import torch

A = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0]])

# LU 分解,包含信息
LU, pivots, info = torch.lu(A, pivot=True, get_infos=True)

print("信息:", info)

Pytorch torch 参考手册 Pytorch torch 参考手册