PyTorch torch.nn.Flatten 函数
torch.nn.Flatten 是 PyTorch 中的张量展平模块。
它将多维张量展平为一维,常用于卷积层和全连接层之间的连接。
函数定义
torch.nn.Flatten(start_dim=1, end_dim=-1)
参数说明:
start_dim(int): 展平开始的维度。默认为 1(保留 batch 维度)。end_dim(int): 展平结束的维度。默认为 -1(到最后一维)。
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn
flatten = nn.Flatten()
# 4D 输入: (batch, channels, height, width)
x = torch.randn(4, 3, 32, 32)
output = flatten(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
print("展平后: 3*32*32 = 3072 维")
import torch.nn as nn
flatten = nn.Flatten()
# 4D 输入: (batch, channels, height, width)
x = torch.randn(4, 3, 32, 32)
output = flatten(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
print("展平后: 3*32*32 = 3072 维")
示例 2: 保留 batch 维度
实例
import torch
import torch.nn as nn
# start_dim=1 保留 batch 维度
x = torch.randn(8, 64, 8, 8)
print("输入:", x.shape)
# 展平到 (8, 4096)
out1 = nn.Flatten(start_dim=1)(x)
print("从维度1开始:", out1.shape)
# 不保留 batch
out2 = nn.Flatten(start_dim=0)(x)
print("从维度0开始:", out2.shape)
import torch.nn as nn
# start_dim=1 保留 batch 维度
x = torch.randn(8, 64, 8, 8)
print("输入:", x.shape)
# 展平到 (8, 4096)
out1 = nn.Flatten(start_dim=1)(x)
print("从维度1开始:", out1.shape)
# 不保留 batch
out2 = nn.Flatten(start_dim=0)(x)
print("从维度0开始:", out2.shape)
示例 3: 3D 输入
实例
import torch
import torch.nn as nn
# 3D 输入: (batch, seq_len, features)
x = torch.randn(4, 100, 512)
# 展平序列和特征
flatten = nn.Flatten(start_dim=1)
output = flatten(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
import torch.nn as nn
# 3D 输入: (batch, seq_len, features)
x = torch.randn(4, 100, 512)
# 展平序列和特征
flatten = nn.Flatten(start_dim=1)
output = flatten(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
示例 4: 完整 CNN 示例
实例
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(64, 10)
)
x = torch.randn(4, 3, 32, 32)
output = model(x)
print("输入:", x.shape, "-> 输出:", output.shape)
import torch.nn as nn
model = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(64, 10)
)
x = torch.randn(4, 3, 32, 32)
output = model(x)
print("输入:", x.shape, "-> 输出:", output.shape)
使用场景
- CNN 到 FC: 卷积层输出展平后连接全连接层
- 维度变换: 调整张量形状

PyTorch torch.nn 参考手册