现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.Flatten 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.Flatten 是 PyTorch 中的张量展平模块。

它将多维张量展平为一维,常用于卷积层和全连接层之间的连接。

函数定义

torch.nn.Flatten(start_dim=1, end_dim=-1)

参数说明:

  • start_dim (int): 展平开始的维度。默认为 1(保留 batch 维度)。
  • end_dim (int): 展平结束的维度。默认为 -1(到最后一维)。

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

flatten = nn.Flatten()

# 4D 输入: (batch, channels, height, width)
x = torch.randn(4, 3, 32, 32)

output = flatten(x)

print("输入形状:", x.shape)
print("输出形状:", output.shape)
print("展平后: 3*32*32 = 3072 维")

示例 2: 保留 batch 维度

实例

import torch
import torch.nn as nn

# start_dim=1 保留 batch 维度
x = torch.randn(8, 64, 8, 8)
print("输入:", x.shape)

# 展平到 (8, 4096)
out1 = nn.Flatten(start_dim=1)(x)
print("从维度1开始:", out1.shape)

# 不保留 batch
out2 = nn.Flatten(start_dim=0)(x)
print("从维度0开始:", out2.shape)

示例 3: 3D 输入

实例

import torch
import torch.nn as nn

# 3D 输入: (batch, seq_len, features)
x = torch.randn(4, 100, 512)

# 展平序列和特征
flatten = nn.Flatten(start_dim=1)
output = flatten(x)

print("输入形状:", x.shape)
print("输出形状:", output.shape)

示例 4: 完整 CNN 示例

实例

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10)
)

x = torch.randn(4, 3, 32, 32)
output = model(x)

print("输入:", x.shape, "-> 输出:", output.shape)

使用场景

  • CNN 到 FC: 卷积层输出展平后连接全连接层
  • 维度变换: 调整张量形状

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册