现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.AdaptiveAvgPool2d 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.AdaptiveAvgPool2d 是 PyTorch 中的自适应平均池化模块。

它可以将任意尺寸的输入特征图池化到指定的目标尺寸,无需手动计算 kernel_size 和 stride。

函数定义

torch.nn.AdaptiveAvgPool2d(output_size)

参数说明:

  • output_size (int 或 tuple): 输出尺寸。可以是 (H, W) 或单个 int(正方形)。

使用示例

示例 1: 基本用法

将特征图池化到固定尺寸:

实例

import torch
import torch.nn as nn

# 池化到 1x1
adaptive_pool = nn.AdaptiveAvgPool2d(1)

# 输入不同尺寸
x1 = torch.randn(1, 64, 32, 32)
x2 = torch.randn(1, 64, 16, 16)
x3 = torch.randn(1, 64, 8, 8)

out1 = adaptive_pool(x1)
out2 = adaptive_pool(x2)
out3 = adaptive_pool(x3)

print("32x32 ->", out1.shape)
print("16x16 ->", out2.shape)
print("8x8   ->", out3.shape)
print("n所有输出都被池化为 1x1")

示例 2: 输出任意尺寸

池化到非正方形:

实例

import torch
import torch.nn as nn

# 池化到 4x4
pool = nn.AdaptiveAvgPool2d((4, 4))

x = torch.randn(2, 128, 32, 32)
out = pool(x)

print("输入形状:", x.shape)
print("输出形状:", out.shape)  # (2, 128, 4, 4)

# 池化到 1x7(可用于序列)
pool_seq = nn.AdaptiveAvgPool2d((1, 7))
x_seq = torch.randn(2, 64, 10, 20)
out_seq = pool_seq(x_seq)
print("n序列池化:", x_seq.shape, "->", out_seq.shape)

示例 3: 在 CNN 中使用

典型的全局平均池化:

实例

import torch
import torch.nn as nn

class CNNWithGAP(nn.Module):
    def __init__(self, num_classes=10):
        super(CNNWithGAP, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # 全局平均池化
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

model = CNNWithGAP()
x = torch.randn(4, 3, 32, 32)
output = model(x)

print("输入形状:", x.shape)
print("输出形状:", output.shape)
print("n特征图被全局池化为单一向量")

示例 4: 与普通池化对比

实例

import torch
import torch.nn as nn

# 自适应池化 vs 普通池化
adaptive = nn.AdaptiveAvgPool2d(2)
regular = nn.AvgPool2d(kernel_size=16, stride=16)

x = torch.randn(1, 64, 32, 32)

out_adaptive = adaptive(x)
out_regular = regular(x)

print("输入:", x.shape)
print("自适应池化 (输出2x2):", out_adaptive.shape)
print("普通平均池化 (16x16):", out_regular.shape)

自适应池化 vs 普通池化

类型 优点 适用场景
AdaptiveAvgPool 输入尺寸任意 不同尺度输入
AvgPool 计算更快 固定输入尺寸

常见问题

Q1: 何时使用自适应池化?

当输入尺寸不固定,或需要统一到固定尺寸时。

Q2: Global Average Pooling 是什么?

AdaptiveAvgPool2d(1),将每个特征图池化为单个值。


使用场景

  • 分类网络: 全局平均池化替代 FC 层
  • 多尺度输入: 适应不同尺寸图片
  • 特征汇聚: 提取关键信息

提示:全局平均池化可以显著减少参数量。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册