现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.quantized_batch_norm 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.quantized_batch_norm 是 PyTorch 中用于对量化张量执行批归一化的函数。该函数在量化模型推理时非常有用,可以在保持量化优势的同时进行批归一化操作。

函数定义

torch.quantized_batch_norm(input, weight, bias, mean, var, eps, output_scale, output_zero_point)

参数说明

  • input: 输入的量化张量
  • weight: 批归一化的缩放参数
  • bias: 批归一化的偏置参数
  • mean: 批归一化的均值
  • var: 批归一化的方差
  • eps: 防止除零的小常数
  • output_scale: 输出张量的量化 scale
  • output_zero_point: 输出张量的量化 zero point

使用示例

实例

import torch

# 创建量化输入张量
input = torch.quantize_per_tensor(torch.randn(1, 3, 4, 4), scale=0.1, zero_point=0, dtype=torch.quint8)

# 批归一化参数
weight = torch.ones(3)
bias = torch.zeros(3)
mean = torch.ones(3) * 0.5
var = torch.ones(3) * 0.2

# 执行量化批归一化
output = torch.quantized_batch_norm(
    input, weight, bias, mean, var,
    eps=1e-5, output_scale=0.1, output_zero_point=0
)

print("输出形状:", output.shape)
print("输出类型:", output.dtype)

输出结果为:

输出形状: torch.Size([1, 3, 4, 4])
输出类型: torch.quint8

Pytorch torch 参考手册 Pytorch torch 参考手册