PyTorch torch.fake_quantize_per_channel_affine 函数
torch.fake_quantize_per_channel_affine 是 PyTorch 中用于对张量进行通道级假量化(Fake Quantization)的函数,常用于量化训练。
函数定义
torch.fake_quantize_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max)
使用示例
实例
import torch
# 创建输入张量
x = torch.randn(2, 3, 4, 5)
# 定义缩放因子和零点
scale = torch.tensor([1.0, 1.2, 1.5])
zero_point = torch.tensor([0, 0, 0], dtype=torch.long)
# 进行通道级假量化
y = torch.fake_quantize_per_channel_affine(x, scale, zero_point, axis=1, quant_min=0, quant_max=255)
print("量化后的形状:", y.shape)
# 创建输入张量
x = torch.randn(2, 3, 4, 5)
# 定义缩放因子和零点
scale = torch.tensor([1.0, 1.2, 1.5])
zero_point = torch.tensor([0, 0, 0], dtype=torch.long)
# 进行通道级假量化
y = torch.fake_quantize_per_channel_affine(x, scale, zero_point, axis=1, quant_min=0, quant_max=255)
print("量化后的形状:", y.shape)

Pytorch torch 参考手册