现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.RNN 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.RNN 是 PyTorch 中的基础循环神经网络模块。

它是最简单的循环层,但容易面临梯度消失问题。

函数定义

torch.nn.RNN(input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0, bidirectional=False)

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=256, hidden_size=128, num_layers=2, batch_first=True)

x = torch.randn(4, 10, 256)
output, hidden = rnn(x)

print("输入:", x.shape)
print("输出:", output.shape)
print("隐藏:", hidden.shape)

示例 2: 多层 RNN

实例

import torch
import torch.nn as nn

# 3层 RNN,带 dropout
rnn = nn.RNN(128, 256, num_layers=3, dropout=0.3, batch_first=True)

x = torch.randn(2, 50, 128)
out, h = rnn(x)

print("输入:", x.shape)
print("输出:", out.shape)
print("隐藏:", h.shape)  # (3, 2, 256)

示例 3: 非线性激活

实例

import torch
import torch.nn as nn

# 默认使用 tanh 作为激活函数
rnn = nn.RNN(64, 64, batch_first=True)

x = torch.randn(1, 5, 64)
out, _ = rnn(x)

print("输出形状:", out.shape)
print("RNN 默认使用 tanh 激活")

注意事项

基础 RNN 存在梯度消失问题,长序列建议使用 LSTM 或 GRU。


使用场景

  • 简单序列任务: 短序列
  • 教学示例: 理解 RNN 原理
  • 快速原型: 简单基线

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册