现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.AvgPool2d 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.AvgPool2d 是 PyTorch 中的二维平均池化模块。

它对输入的每个窗口计算平均值,常用于下采样和特征汇聚。

函数定义

torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True)

参数说明

  • kernel_size: 池化窗口大小
  • stride: 步长,默认为 kernel_size
  • padding: 填充大小
  • ceil_mode: 是否使用 ceil 计算输出尺寸

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

pool = nn.AvgPool2d(kernel_size=2, stride=2)

x = torch.randn(1, 1, 4, 4)
output = pool(x)

print("输入:n", x.squeeze().tolist())
print("n输出:n", output.squeeze().tolist())
print("形状:", x.shape, "->", output.shape)

示例 2: 全局平均池化

实例

import torch
import torch.nn as nn

# 全局平均池化
gap = nn.AdaptiveAvgPool2d(1)

x = torch.randn(4, 64, 16, 16)
out = gap(x)

print("输入:", x.shape)
print("输出:", out.shape)
print("每个通道的平均值数量:", out.numel() // 64)

示例 3: 对比 MaxPool

实例

import torch
import torch.nn as nn

x = torch.tensor([[[
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16]
]]], dtype=torch.float32)

maxpool = nn.MaxPool2d(2, 2)
avgpool = nn.AvgPool2d(2, 2)

print("输入:n", x[0, 0])
print("nMaxPool:n", maxpool(x)[0, 0])
print("nAvgPool:n", avgpool(x)[0, 0])

使用场景

  • 特征汇聚: 减少空间维度
  • 全局平均池化: 替代全连接层
  • 平滑特征减少噪声

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册