PyTorch torch.ormqr 函数
torch.ormqr 是 PyTorch 中使用 QR 分解进行矩阵乘法的函数。它计算 Q @ input 或 input @ Q,其中 Q 是从 QR 分解得到的正交矩阵。
函数定义
torch.ormqr(input, tau, left=True, transpose=False, out=None)
参数:
input(Tensor): 输入矩阵。tau(Tensor): Householder 反射器系数。left(bool, 可选): 如果为 True,计算 Q @ input;否则计算 input @ Q。默认为 True。transpose(bool, 可选): 是否转置 Q。默认为 False。out(Tensor, 可选): 输出张量。
返回值:
torch.Tensor: 返回矩阵乘法结果。
使用示例
实例
import torch
# 创建矩阵
A = torch.randn(3, 3)
B = torch.randn(3, 4)
# QR 分解
Q, R = torch.linalg.qr(A)
# 计算 Q @ B
result = torch.ormqr(B, torch.zeros(3))
print("矩阵 B 形状:", B.shape)
print("结果形状:", result.shape)
# 创建矩阵
A = torch.randn(3, 3)
B = torch.randn(3, 4)
# QR 分解
Q, R = torch.linalg.qr(A)
# 计算 Q @ B
result = torch.ormqr(B, torch.zeros(3))
print("矩阵 B 形状:", B.shape)
print("结果形状:", result.shape)
输出结果为:
矩阵 B 形状: torch.Size([3, 4]) 结果形状: torch.Size([3, 4])

Pytorch torch 参考手册