PyTorch torch.svd_lowrank 函数
torch.svd_lowrank 是 PyTorch 中用于计算矩阵低秩近似 SVD 的函数。它使用随机方法计算部分 SVD,比完整 SVD 更快。
函数定义
torch.svd_lowrank(A, q=6, niter=2, mexp=2)
参数:
A(Tensor): 输入矩阵。q(int, 可选): 幂迭代次数。默认为 6。niter(int, 可选): 随机迭代次数。默认为 2。mexp(int, 可选): 矩阵指数。默认为 2。
返回值:
tuple: 返回 (U, S, V) 的元组。
使用示例
实例
import torch
# 创建大矩阵
A = torch.randn(100, 50)
# 低秩 SVD (计算前 10 个奇异值)
U, S, V = torch.svd_lowrank(A, q=10)
print("矩阵 A 形状:", A.shape)
print("U 形状:", U.shape)
print("奇异值 S 形状:", S.shape)
print("V 形状:", V.shape)
print("n奇异值:")
print(S[:10])
# 创建大矩阵
A = torch.randn(100, 50)
# 低秩 SVD (计算前 10 个奇异值)
U, S, V = torch.svd_lowrank(A, q=10)
print("矩阵 A 形状:", A.shape)
print("U 形状:", U.shape)
print("奇异值 S 形状:", S.shape)
print("V 形状:", V.shape)
print("n奇异值:")
print(S[:10])
输出结果为:
矩阵 A 形状: torch.Size([100, 50])
U 形状: torch.Size([100, 10])
奇异值 S 形状: torch.Size([10])
V 形状: torch.Size([50, 10])
奇异值:
tensor([14.6526, 14.0891, 13.5830, 13.0966, 12.7677, 12.4337, 12.1510,
11.8361, 11.5309, 11.3064])

Pytorch torch 参考手册