现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.lu_unpack 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.lu_unpack 是 PyTorch 中用于从 LU 分解的结果中解压出 L、U 和 P 矩阵的函数。

函数定义

torch.lu_unpack(LU_data, LU_pivots, unpack_data=True, out=None)

参数:

  • LU_data (Tensor): LU 分解得到的矩阵。
  • LU_pivots (Tensor): LU 分解的主元索引。
  • unpack_data (bool, 可选): 是否解压数据。默认为 True。
  • out (tuple, 可选): 输出元组。

返回值:

  • tuple: 返回 (pivot, L, U) 的元组。

使用示例

实例

import torch

# 创建矩阵
A = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0]])

# LU 分解
LU, pivots = torch.lu(A)

# 解压出 L、U、P
P, L, U = torch.lu_unpack(LU, pivots)

print("矩阵 A:")
print(A)
print("n置换矩阵 P:")
print(P)
print("n下三角矩阵 L:")
print(L)
print("n上三角矩阵 U:")
print(U)
print("n验证: P @ L @ U =")
print(P @ L @ U)

输出结果为:

矩阵 A:
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
置换矩阵 P:
tensor([[0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.]])
下三角矩阵 L:
tensor([[1.0000, 0.0000, 0.0000],
        [0.1429, 1.0000, 0.0000],
        [0.5714, 0.5000, 1.0000]])
上三角矩阵 U:
tensor([[7.0000, 8.0000, 9.0000],
        [0.0000, 0.8571, 1.7143],
        [0.0000, 0.0000, 0.0000]])
验证: P @ L @ U =
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])

Pytorch torch 参考手册 Pytorch torch 参考手册