PyTorch torch.renorm 函数
torch.renorm 是 PyTorch 中用于重新归一化张量的函数。它根据指定的范数对张量进行归一化,使得指定维度上的元素范数不超过给定值。
函数定义
torch.renorm(input, p, dim, maxnorm)
参数说明:
input: 输入张量p: 范数阶数dim: 归一化的维度maxnorm: 最大范数值
使用示例
实例
import torch
# 创建张量
x = torch.tensor([[2.0, 4.0, 6.0], [3.0, 6.0, 9.0]])
# 在 dim=1 上进行 L2 范数归一化,最大范数为 1
y = torch.renorm(x, p=2, dim=1, maxnorm=1)
print(y)
# 创建张量
x = torch.tensor([[2.0, 4.0, 6.0], [3.0, 6.0, 9.0]])
# 在 dim=1 上进行 L2 范数归一化,最大范数为 1
y = torch.renorm(x, p=2, dim=1, maxnorm=1)
print(y)
输出结果为:
tensor([[0.2673, 0.5345, 0.8018],
[0.2673, 0.5345, 0.8018]])
实例
import torch
# 创建张量
x = torch.tensor([[1.0, 2.0, 3.0]])
# 在 dim=1 上进行 L1 范数归一化
y = torch.renorm(x, p=1, dim=1, maxnorm=1)
print(y)
# 创建张量
x = torch.tensor([[1.0, 2.0, 3.0]])
# 在 dim=1 上进行 L1 范数归一化
y = torch.renorm(x, p=1, dim=1, maxnorm=1)
print(y)
输出结果为:
tensor([[0.1667, 0.3333, 0.5000]])

Pytorch torch 参考手册