现在位置: 首页 > PyTorch 教程 > 正文

PyTorch nn.MaxPool2d 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.MaxPool2d 是 PyTorch 中用于二维最大池化的模块。

池化层可以降低特征图的空间尺寸,减少计算量,同时提供一定程度的平移不变性。

函数定义

torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

参数说明:

  • kernel_size (int 或 tuple): 池化窗口的大小。
  • stride (int 或 tuple): 池化窗口移动的步长。默认为 kernel_size
  • padding (int 或 tuple): 输入边缘的填充大小。默认为 0。
  • dilation (int 或 tuple): 窗口元素的间距。默认为 1。
  • return_indices (bool): 是否返回最大值的索引。用于 MaxUnpool2d。默认为 False。
  • ceil_mode (bool): 是否使用 ceil 而非 floor 计算输出尺寸。默认为 False。

使用示例

示例 1: 基本用法

创建一个最大池化层:

实例

import torch
import torch.nn as nn

# 创建最大池化层:窗口 2x2,步长 2
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

# 创建输入张量
input_tensor = torch.randn(1, 1, 4, 4)
print("输入:n", input_tensor.squeeze().tolist())

# 前向传播
output = max_pool(input_tensor)

print("n输出:n", output.squeeze().tolist())
print("n输入形状:", input_tensor.shape)
print("输出形状:", output.shape)

输出结果为:

输入:
[[-0.4128, 0.2341, -0.9876, 0.4567],
 [ 0.1234, 0.8765, -0.2345, 0.6789],
 [-0.5678, 0.3456, 0.7890, -0.1234],
 [ 0.9012, -0.4567, 0.2345, 0.5678]]

输出:
[[0.8765, 0.6789],
 [0.9012, 0.7890]]

输入形状: torch.Size([1, 1, 4, 4])
输出形状: torch.Size([1, 1, 2, 2])

可以看到,每个 2x2 窗口中的最大值被保留下来。

示例 2: 不同的 kernel_size 和 stride

调整池化参数:

实例

import torch
import torch.nn as nn

# 3x3 池化,步长为 1(不重叠)
pool3x3 = nn.MaxPool2d(kernel_size=3, stride=1)

# 非方形池化
pool_rect = nn.MaxPool2d(kernel_size=(2, 3), stride=(2, 3))

input_tensor = torch.randn(1, 1, 6, 9)

print("输入形状:", input_tensor.shape)
print("3x3 池化输出:", pool3x3(input_tensor).shape)
print("矩形池化输出:", pool_rect(input_tensor).shape)

示例 3: 使用 padding

边缘填充可以在一定程度上保留边缘信息:

实例

import torch
import torch.nn as nn

# 带 padding 的池化
pool_padding = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

input_tensor = torch.randn(1, 1, 4, 4)
output = pool_padding(input_tensor)

print("输入形状:", input_tensor.shape)
print("输出形状 (带 padding):", output.shape)

示例 4: 返回索引

使用 return_indices 可以在解码器中恢复位置:

实例

import torch
import torch.nn as nn

# 创建返回索引的池化层
pool_indices = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

input_tensor = torch.randn(1, 1, 4, 4)
output, indices = pool_indices(input_tensor)

print("输出形状:", output.shape)
print("索引形状:", indices.shape)
print("索引值:", indices.squeeze().tolist())

示例 5: 在 CNN 中使用

典型的 CNN 结构中池化层的位置:

实例

import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # 卷积层
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

        # 池化层:每经过一次,尺寸减半
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.relu(self.conv1(x))  # 32x32
        x = self.pool(x)              # 16x16

        x = self.relu(self.conv2(x))  # 16x16
        x = self.pool(x)              # 8x8

        return x

model = SimpleCNN()
input_image = torch.randn(1, 3, 32, 32)
output = model(input_image)

print("输入形状:", input_image.shape)
print("输出形状:", output.shape)

最大池化 vs 平均池化

类型 公式 特点 适用场景
MaxPool2d max(区域) 保留显著特征,对噪声更鲁棒 图像分类、目标检测(常用)
AvgPool2d mean(区域) 平滑特征,保留背景信息 全局平均池化、特征提取

常见问题

Q1: 池化层是否可以移除?

现代网络如 ResNet、DenseNet 倾向于使用较小的 stride 卷积代替池化,但池化仍常用于快速下采样。

Q2: stride 和 kernel_size 有什么关系?

当 stride = kernel_size 时,池化窗口不重叠;当 stride < kernel_size 时,池化窗口有重叠。


使用场景

nn.MaxPool2d 主要应用场景包括:

  • 图像分类网络: 逐步降低分辨率,提取高级特征
  • 目标检测: 保留显著特征位置
  • 降低计算量: 减少特征图尺寸
  • 增加感受野: 让后续层看到更大范围的特征

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册