现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.LSTM 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.LSTM 是 PyTorch 中用于长短期记忆网络的模块。

LSTM 是一种特殊的循环神经网络,能够学习长期依赖关系,广泛用于序列建模任务。

函数定义

torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0, bidirectional=False)

参数说明:

  • input_size (int): 输入特征维度。
  • hidden_size (int): 隐藏状态维度。
  • num_layers (int): LSTM 层数。默认为 1。
  • bias (bool): 是否使用偏置。默认为 True。
  • batch_first (bool): 如果为 True,输入输出形状为 (batch, seq, feature)。默认为 True。
  • dropout (float): 非最后一层使用的 dropout。默认为 0。
  • bidirectional (bool): 是否使用双向 LSTM。默认为 False。

输入输出

输入:

  • input: 形状为 (batch, seq_len, input_size) 的张量
  • h_0: 初始隐藏状态,形状为 (num_layers * num_directions, batch, hidden_size)
  • c_0: 初始细胞状态,形状为 (num_layers * num_directions, batch, hidden_size)

输出:

  • output: 最后一个隐藏层的输出,形状为 (batch, seq_len, num_directions * hidden_size)
  • h_n: 所有层的最后一个隐藏状态
  • c_n: 所有层的最后一个细胞状态

使用示例

示例 1: 基本用法

创建并使用 LSTM:

实例

import torch
import torch.nn as nn

# 创建 LSTM:输入维度 256,隐藏维度 512,2 层
lstm = nn.LSTM(input_size=256, hidden_size=512, num_layers=2, batch_first=True)

# 创建输入:batch=4,序列长度=10,输入维度=256
input_tensor = torch.randn(4, 10, 256)

# 前向传播
output, (h_n, c_n) = lstm(input_tensor)

print("输入形状:", input_tensor.shape)
print("输出形状:", output.shape)      # (4, 10, 512)
print("隐藏状态形状:", h_n.shape)     # (2, 4, 512)
print("细胞状态形状:", c_n.shape)     # (2, 4, 512)

示例 2: 双向 LSTM

使用双向 LSTM 捕获双向上下文:

实例

import torch
import torch.nn as nn

# 双向 LSTM
bilstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=2, batch_first=True, bidirectional=True)

input_tensor = torch.randn(4, 10, 256)
output, (h_n, c_n) = bilstm(input_tensor)

print("双向 LSTM 输出形状:", output.shape)   # (4, 10, 512) = 256*2
print("隐藏状态形状:", h_n.shape)            # (4, 4, 256) = 2层*2方向
print("最后一层隐藏状态:", h_n[-2:, :, :].shape)  # 正向和反向

示例 3: 初始化隐藏状态

手动初始化隐藏状态:

实例

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=256, hidden_size=512, batch_first=True)

# 手动创建初始隐藏状态
batch_size = 4
num_layers = 2
hidden_size = 512

h_0 = torch.zeros(num_layers, batch_size, hidden_size)
c_0 = torch.zeros(num_layers, batch_size, hidden_size)

# 传入初始状态
input_tensor = torch.randn(4, 10, 256)
output, (h_n, c_n) = lstm(input_tensor, (h_0, c_0))

print("使用自定义初始状态完成")
print("输出形状:", output.shape)

示例 4: 完整的情感分类模型

基于 LSTM 的文本分类:

实例

import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes=2):
        super(LSTMClassifier, self).__init__()
        # 嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # LSTM 层
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3
        )
        # 全连接分类层
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        # x: (batch, seq_len)
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)

        # LSTM 输出
        output, (hidden, cell) = self.lstm(embedded)

        # 拼接双向最后一层的隐藏状态
        # hidden: (4, batch, hidden_dim) - 2层 * 2方向
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)  # (batch, hidden_dim*2)

        # 分类
        logits = self.fc(hidden)
        return logits

# 实例化模型
vocab_size = 10000
model = LSTMClassifier(vocab_size=vocab_size, embed_dim=128, hidden_dim=128, num_classes=2)

# 测试输入:batch=8,序列长度=50
input_ids = torch.randint(1, vocab_size, (8, 50))
output = model(input_ids)

print("模型结构:")
print(model)
print("n输入形状:", input_ids.shape)
print("输出形状:", output.shape)  # (8, 2)

示例 5: 多层堆叠 LSTM

深层 LSTM 网络:

实例

import torch
import torch.nn as nn

# 4 层堆叠的 LSTM,带 dropout
deep_lstm = nn.LSTM(
    input_size=256,
    hidden_size=512,
    num_layers=4,
    batch_first=True,
    dropout=0.4  # 每层之间的 dropout
)

input_tensor = torch.randn(2, 20, 256)
output, (h_n, c_n) = deep_lstm(input_tensor)

print("4层 LSTM 输出形状:", output.shape)
print("隐藏状态形状 (4层):", h_n.shape)
print("细胞状态形状 (4层):", c_n.shape)

LSTM 门的概念

LSTM 通过三个门控制信息流:

  • 遗忘门: 决定保留多少上一时刻的信息
  • 输入门: 决定加入多少新信息
  • 输出门: 决定输出多少信息

常见问题

Q1: batch_first=True 是什么意思?

输入输出张量的第一维是 batch_size。如果为 False,则第一维是序列长度。

Q2: 双向 LSTM 什么时候用?

序列标注、情感分析等需要双向上下文的任务。机器翻译常用encoder-decoder架构。

Q3: 如何选择隐藏层大小?

通常 128-512,根据任务复杂度和数据量调整。太小欠拟合,太大过拟合。


使用场景

nn.LSTM 主要应用场景包括:

  • 自然语言处理: 文本分类、命名实体识别
  • 时间序列预测: 股票预测、语音识别
  • 序列到序列任务: 机器翻译、文本生成

提示:使用 bidirectional=True 时,输出维度会变成 hidden_size * 2。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册