PyTorch torch.quantize_per_channel 函数
torch.quantize_per_channel 是 PyTorch 中用于创建按通道(per-channel)量化的量化张量的函数。
函数定义
torch.quantize_per_channel(input, scales, zero_points, axis, dtype)
使用示例
实例
import torch
# 创建输入张量 (2D)
input = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
# 按通道量化
# scales: 每个通道的缩放因子
# zero_points: 每个通道的零点
# axis: 指定沿哪个维度进行量化
scales = torch.tensor([0.1, 0.2])
zero_points = torch.tensor([10, 10])
axis = 0
x = torch.quantize_per_channel(input, scales, zero_points, axis, dtype=torch.quint8)
print("量化张量:")
print(x)
print("原始值反量化:")
print(x.dequantize())
# 创建输入张量 (2D)
input = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
# 按通道量化
# scales: 每个通道的缩放因子
# zero_points: 每个通道的零点
# axis: 指定沿哪个维度进行量化
scales = torch.tensor([0.1, 0.2])
zero_points = torch.tensor([10, 10])
axis = 0
x = torch.quantize_per_channel(input, scales, zero_points, axis, dtype=torch.quint8)
print("量化张量:")
print(x)
print("原始值反量化:")
print(x.dequantize())

Pytorch torch 参考手册