PyTorch torch.nn.LSTM
torch.nn.LSTM is the PyTorch module for long short-term memory networks.
An LSTM is a special kind of recurrent neural network that can learn long-range dependencies, and it is widely used in sequence-modeling tasks.
Definition
torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False)
Parameters:
- input_size (int): number of features in the input.
- hidden_size (int): number of features in the hidden state.
- num_layers (int): number of stacked LSTM layers. Default: 1.
- bias (bool): whether to use bias weights. Default: True.
- batch_first (bool): if True, input and output tensors have shape (batch, seq, feature). Default: False.
- dropout (float): dropout applied to the outputs of every layer except the last; it only takes effect when num_layers > 1. Default: 0.
- bidirectional (bool): whether to use a bidirectional LSTM. Default: False.
Inputs and Outputs
Inputs:
- input: tensor of shape (batch, seq_len, input_size) when batch_first=True
- h_0: initial hidden state, shape (num_layers * num_directions, batch, hidden_size)
- c_0: initial cell state, shape (num_layers * num_directions, batch, hidden_size)
Outputs:
- output: per-timestep output of the last layer, shape (batch, seq_len, num_directions * hidden_size) when batch_first=True
- h_n: final hidden state for every layer (and direction)
- c_n: final cell state for every layer (and direction)
Examples
Example 1: Basic Usage
Create and run an LSTM:
Example
import torch
import torch.nn as nn
# Create an LSTM: input size 256, hidden size 512, 2 layers
lstm = nn.LSTM(input_size=256, hidden_size=512, num_layers=2, batch_first=True)
# Create the input: batch=4, sequence length=10, input size=256
input_tensor = torch.randn(4, 10, 256)
# Forward pass
output, (h_n, c_n) = lstm(input_tensor)
print("Input shape:", input_tensor.shape)
print("Output shape:", output.shape)  # (4, 10, 512)
print("Hidden state shape:", h_n.shape)  # (2, 4, 512)
print("Cell state shape:", c_n.shape)  # (2, 4, 512)
Example 2: Bidirectional LSTM
Use a bidirectional LSTM to capture context from both directions:
Example
import torch
import torch.nn as nn
# Bidirectional LSTM
bilstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=2, batch_first=True, bidirectional=True)
input_tensor = torch.randn(4, 10, 256)
output, (h_n, c_n) = bilstm(input_tensor)
print("Bidirectional LSTM output shape:", output.shape)  # (4, 10, 512) = 256*2
print("Hidden state shape:", h_n.shape)  # (4, 4, 256) = 2 layers * 2 directions
print("Last layer hidden states:", h_n[-2:, :, :].shape)  # (2, 4, 256): forward and backward
Example 3: Initializing the Hidden State
Manually initialize the hidden and cell states:
Example
import torch
import torch.nn as nn
# num_layers=2 must match the shape of the initial states below
lstm = nn.LSTM(input_size=256, hidden_size=512, num_layers=2, batch_first=True)
# Manually create the initial hidden and cell states
batch_size = 4
num_layers = 2
hidden_size = 512
h_0 = torch.zeros(num_layers, batch_size, hidden_size)
c_0 = torch.zeros(num_layers, batch_size, hidden_size)
# Pass the initial states to the forward call
input_tensor = torch.randn(4, 10, 256)
output, (h_n, c_n) = lstm(input_tensor, (h_0, c_0))
print("Forward pass with custom initial states done")
print("Output shape:", output.shape)
Example 4: A Complete Sentiment-Classification Model
Text classification with an LSTM:
Example
import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes=2):
        super(LSTMClassifier, self).__init__()
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # LSTM layer
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3
        )
        # Fully connected classification layer
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        # x: (batch, seq_len)
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)
        # LSTM output
        output, (hidden, cell) = self.lstm(embedded)
        # Concatenate the last layer's forward and backward hidden states
        # hidden: (4, batch, hidden_dim) - 2 layers * 2 directions
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)  # (batch, hidden_dim*2)
        # Classify
        logits = self.fc(hidden)
        return logits

# Instantiate the model
vocab_size = 10000
model = LSTMClassifier(vocab_size=vocab_size, embed_dim=128, hidden_dim=128, num_classes=2)
# Test input: batch=8, sequence length=50
input_ids = torch.randint(1, vocab_size, (8, 50))
output = model(input_ids)
print("Model structure:")
print(model)
print("\nInput shape:", input_ids.shape)
print("Output shape:", output.shape)  # (8, 2)
Example 5: Stacked Multi-Layer LSTM
A deeper LSTM network:
Example
import torch
import torch.nn as nn
# 4 stacked LSTM layers with dropout
deep_lstm = nn.LSTM(
    input_size=256,
    hidden_size=512,
    num_layers=4,
    batch_first=True,
    dropout=0.4  # dropout between layers
)
input_tensor = torch.randn(2, 20, 256)
output, (h_n, c_n) = deep_lstm(input_tensor)
print("4-layer LSTM output shape:", output.shape)
print("Hidden state shape (4 layers):", h_n.shape)
print("Cell state shape (4 layers):", c_n.shape)
LSTM Gates
An LSTM controls the flow of information through three gates (made concrete in the sketch after this list):
- Forget gate: decides how much of the previous cell state to keep
- Input gate: decides how much new information to add
- Output gate: decides how much of the cell state to expose as output
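The sketch below makes the gates concrete by recomputing one nn.LSTMCell step by hand; PyTorch stacks the four gate weights in the order (input, forget, cell, output) along the first dimension of weight_ih and weight_hh. Sizes here are illustrative.
import torch
import torch.nn as nn
cell = nn.LSTMCell(input_size=8, hidden_size=16)
x = torch.randn(4, 8)
h = torch.zeros(4, 16)
c = torch.zeros(4, 16)
# All four gates in one affine map, then split into i, f, g, o
gates = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
i, f, g, o = gates.chunk(4, dim=1)
i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
g = torch.tanh(g)              # candidate cell state
c_new = f * c + i * g          # forget gate keeps old memory, input gate adds new
h_new = o * torch.tanh(c_new)  # output gate controls what is exposed
h_ref, c_ref = cell(x, (h, c))
print(torch.allclose(h_new, h_ref, atol=1e-6))  # True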
FAQ
Q1: What does batch_first=True mean?
The first dimension of the input and output tensors is batch_size. If False (the default), the first dimension is the sequence length. Note that h_n and c_n always have shape (num_layers * num_directions, batch, hidden_size), regardless of batch_first.
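A quick shape check under both settings (a minimal sketch):
import torch
import torch.nn as nn
x = torch.randn(4, 10, 256)  # (batch, seq_len, feature)
lstm_bf = nn.LSTM(256, 512, batch_first=True)
out, (h_n, _) = lstm_bf(x)
print(out.shape)  # (4, 10, 512): (batch, seq, feature)
lstm = nn.LSTM(256, 512)  # batch_first=False is the default
out, (h_n2, _) = lstm(x.transpose(0, 1))  # expects (seq, batch, feature)
print(out.shape)  # (10, 4, 512)
# h_n keeps the same layout either way
print(h_n.shape, h_n2.shape)  # (1, 4, 512) (1, 4, 512)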
Q2: When should I use a bidirectional LSTM?
For tasks that need context from both directions, such as sequence labeling and sentiment analysis. Machine translation typically uses an encoder-decoder architecture instead.
Q3: How do I choose the hidden size?
Typically 128-512, tuned to the task complexity and the amount of data. Too small underfits; too large overfits.
Use Cases
Typical applications of nn.LSTM include:
- Natural language processing: text classification, named entity recognition
- Time-series prediction: stock forecasting, speech recognition
- Sequence-to-sequence tasks: machine translation, text generation
Tip: with bidirectional=True, the output feature dimension becomes hidden_size * 2.
