PyTorch torch.quantized_batch_norm 函数
torch.quantized_batch_norm 是 PyTorch 中用于对量化张量执行批归一化的函数。该函数在量化模型推理时非常有用,可以在保持量化优势的同时进行批归一化操作。
函数定义
torch.quantized_batch_norm(input, weight, bias, mean, var, eps, output_scale, output_zero_point)
参数说明
input: 输入的量化张量weight: 批归一化的缩放参数bias: 批归一化的偏置参数mean: 批归一化的均值var: 批归一化的方差eps: 防止除零的小常数output_scale: 输出张量的量化 scaleoutput_zero_point: 输出张量的量化 zero point
使用示例
实例
import torch
# 创建量化输入张量
input = torch.quantize_per_tensor(torch.randn(1, 3, 4, 4), scale=0.1, zero_point=0, dtype=torch.quint8)
# 批归一化参数
weight = torch.ones(3)
bias = torch.zeros(3)
mean = torch.ones(3) * 0.5
var = torch.ones(3) * 0.2
# 执行量化批归一化
output = torch.quantized_batch_norm(
input, weight, bias, mean, var,
eps=1e-5, output_scale=0.1, output_zero_point=0
)
print("输出形状:", output.shape)
print("输出类型:", output.dtype)
# 创建量化输入张量
input = torch.quantize_per_tensor(torch.randn(1, 3, 4, 4), scale=0.1, zero_point=0, dtype=torch.quint8)
# 批归一化参数
weight = torch.ones(3)
bias = torch.zeros(3)
mean = torch.ones(3) * 0.5
var = torch.ones(3) * 0.2
# 执行量化批归一化
output = torch.quantized_batch_norm(
input, weight, bias, mean, var,
eps=1e-5, output_scale=0.1, output_zero_point=0
)
print("输出形状:", output.shape)
print("输出类型:", output.dtype)
输出结果为:
输出形状: torch.Size([1, 3, 4, 4]) 输出类型: torch.quint8

Pytorch torch 参考手册