现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.AdaptiveMaxPool2d 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.AdaptiveMaxPool2d 是 PyTorch 中的自适应最大池化模块。

它将输入池化到指定尺寸,保留最大值而非平均值。

函数定义

torch.nn.AdaptiveMaxPool2d(output_size, return_indices=False)

参数

  • output_size: 输出尺寸
  • return_indices: 是否返回索引

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

# 全局最大池化
gap = nn.AdaptiveMaxPool2d(1)

# 输入不同尺寸
x1 = torch.randn(1, 64, 32, 32)
x2 = torch.randn(1, 64, 16, 16)

print("32x32 ->", gap(x1).shape)
print("16x16 ->", gap(x2).shape)

示例 2: 返回索引

实例

import torch
import torch.nn as nn

gap = nn.AdaptiveMaxPool2d(1, return_indices=True)

x = torch.randn(1, 64, 8, 8)
output, indices = gap(x)

print("输出形状:", output.shape)
print("索引形状:", indices.shape)
print("索引值:", indices.item())

示例 3: 与 AdaptiveAvgPool2d 对比

实例

import torch
import torch.nn as nn

x = torch.tensor([[[
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16]
]]], dtype=torch.float32)

avgpool = nn.AdaptiveAvgPool2d(1)
maxpool = nn.AdaptiveMaxPool2d(1)

print("输入:n", x[0, 0])
print("平均池化:", avgpool(x).item())
print("最大池化:", maxpool(x).item())

使用场景

  • 全局最大池化: 提取最显著特征
  • 特征汇聚: 保留关键信息
  • 分类网络: 替代 FC 层

提示:最大池化保留最显著特征,平均池化更平滑。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册