PyTorch torch.nn.AdaptiveAvgPool2d 函数
torch.nn.AdaptiveAvgPool2d 是 PyTorch 中的自适应平均池化模块。
它可以将任意尺寸的输入特征图池化到指定的目标尺寸,无需手动计算 kernel_size 和 stride。
函数定义
torch.nn.AdaptiveAvgPool2d(output_size)
参数说明:
output_size(int 或 tuple): 输出尺寸。可以是 (H, W) 或单个 int(正方形)。
使用示例
示例 1: 基本用法
将特征图池化到固定尺寸:
实例
import torch
import torch.nn as nn

# Adaptive average pooling to a 1x1 output: this is global average pooling,
# so an input of any spatial size collapses to one value per channel.
adaptive_pool = nn.AdaptiveAvgPool2d(1)

# Inputs with different spatial sizes but the same channel count.
x1 = torch.randn(1, 64, 32, 32)
x2 = torch.randn(1, 64, 16, 16)
x3 = torch.randn(1, 64, 8, 8)

out1 = adaptive_pool(x1)
out2 = adaptive_pool(x2)
out3 = adaptive_pool(x3)

print("32x32 ->", out1.shape)
print("16x16 ->", out2.shape)
print("8x8 ->", out3.shape)
# Bug fix: the original string began with a literal "n" where the "\n"
# escape had lost its backslash.
print("\n所有输出都被池化为 1x1")
import torch  # Bug fix: this duplicated snippet used torch.randn without importing torch.
import torch.nn as nn

# Adaptive average pooling to 1x1 (global average pooling): any input
# spatial size is reduced to a single value per channel.
adaptive_pool = nn.AdaptiveAvgPool2d(1)

# Inputs with different spatial sizes.
x1 = torch.randn(1, 64, 32, 32)
x2 = torch.randn(1, 64, 16, 16)
x3 = torch.randn(1, 64, 8, 8)

out1 = adaptive_pool(x1)
out2 = adaptive_pool(x2)
out3 = adaptive_pool(x3)

print("32x32 ->", out1.shape)
print("16x16 ->", out2.shape)
print("8x8 ->", out3.shape)
# Bug fix: restore the "\n" escape that was mangled to a literal "n".
print("\n所有输出都被池化为 1x1")
示例 2: 输出任意尺寸
池化到非正方形:
实例
import torch
import torch.nn as nn

# Pool to a fixed 4x4 output regardless of the input's spatial size.
pool = nn.AdaptiveAvgPool2d((4, 4))
x = torch.randn(2, 128, 32, 32)
out = pool(x)
print("输入形状:", x.shape)
print("输出形状:", out.shape)  # (2, 128, 4, 4)

# Non-square target 1x7 — useful for producing sequence-like features.
pool_seq = nn.AdaptiveAvgPool2d((1, 7))
x_seq = torch.randn(2, 64, 10, 20)
out_seq = pool_seq(x_seq)
# Bug fix: restore the "\n" escape that was mangled to a literal "n".
print("\n序列池化:", x_seq.shape, "->", out_seq.shape)
import torch  # Bug fix: this duplicated snippet used torch.randn without importing torch.
import torch.nn as nn

# Fixed 4x4 output for any input spatial size.
pool = nn.AdaptiveAvgPool2d((4, 4))
x = torch.randn(2, 128, 32, 32)
out = pool(x)
print("输入形状:", x.shape)
print("输出形状:", out.shape)  # (2, 128, 4, 4)

# Non-square 1x7 target, e.g. for sequence-shaped features.
pool_seq = nn.AdaptiveAvgPool2d((1, 7))
x_seq = torch.randn(2, 64, 10, 20)
out_seq = pool_seq(x_seq)
# Bug fix: restore the "\n" escape that was mangled to a literal "n".
print("\n序列池化:", x_seq.shape, "->", out_seq.shape)
示例 3: 在 CNN 中使用
典型的全局平均池化:
实例
import torch
import torch.nn as nn


class CNNWithGAP(nn.Module):
    """Small CNN using global average pooling (GAP) before the classifier.

    GAP (AdaptiveAvgPool2d(1)) collapses each feature map to one scalar, so
    the Linear head works for any input spatial size, not just 32x32.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Global average pooling: (N, 256, H, W) -> (N, 256, 1, 1).
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten to (N, 256)
        x = self.classifier(x)
        return x


model = CNNWithGAP()
x = torch.randn(4, 3, 32, 32)
output = model(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
# Bug fix: restore the "\n" escape that was mangled to a literal "n".
print("\n特征图被全局池化为单一向量")
import torch  # Bug fix: this duplicated snippet used torch.randn without importing torch.
import torch.nn as nn


class CNNWithGAP(nn.Module):
    """CNN whose classifier is fed by global average pooling, making the
    head independent of the input's spatial size."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Global average pooling: (N, 256, H, W) -> (N, 256, 1, 1).
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten to (N, 256)
        x = self.classifier(x)
        return x


model = CNNWithGAP()
x = torch.randn(4, 3, 32, 32)
output = model(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
# Bug fix: restore the "\n" escape that was mangled to a literal "n".
print("\n特征图被全局池化为单一向量")
示例 4: 与普通池化对比
实例
import torch
import torch.nn as nn

# Adaptive pooling vs. plain average pooling. On a 32x32 input both layers
# below yield a 2x2 output: AdaptiveAvgPool2d derives kernel/stride from the
# input size, while AvgPool2d has them hard-coded to 16.
adaptive = nn.AdaptiveAvgPool2d(2)
regular = nn.AvgPool2d(kernel_size=16, stride=16)

x = torch.randn(1, 64, 32, 32)
out_adaptive, out_regular = adaptive(x), regular(x)

print("输入:", x.shape)
print("自适应池化 (输出2x2):", out_adaptive.shape)
print("普通平均池化 (16x16):", out_regular.shape)
import torch  # Bug fix: this duplicated snippet used torch.randn without importing torch.
import torch.nn as nn

# Adaptive pooling (target output size) vs. plain average pooling
# (fixed kernel/stride); both give 2x2 on a 32x32 input.
adaptive = nn.AdaptiveAvgPool2d(2)
regular = nn.AvgPool2d(kernel_size=16, stride=16)

x = torch.randn(1, 64, 32, 32)
out_adaptive = adaptive(x)
out_regular = regular(x)

print("输入:", x.shape)
print("自适应池化 (输出2x2):", out_adaptive.shape)
print("普通平均池化 (16x16):", out_regular.shape)
自适应池化 vs 普通池化
| 类型 | 优点 | 适用场景 |
|---|---|---|
| AdaptiveAvgPool | 输入尺寸任意 | 不同尺度输入 |
| AvgPool | 计算更快 | 固定输入尺寸 |
常见问题
Q1: 何时使用自适应池化?
当输入尺寸不固定,或需要统一到固定尺寸时。
Q2: Global Average Pooling 是什么?
AdaptiveAvgPool2d(1),将每个特征图池化为单个值。
使用场景
- 分类网络: 全局平均池化替代 FC 层
- 多尺度输入: 适应不同尺寸图片
- 特征汇聚: 提取关键信息
提示:全局平均池化可以显著减少参数量。

PyTorch torch.nn 参考手册