PyTorch torch.nn.RNN Function
torch.nn.RNN is the basic recurrent neural network module in PyTorch.
It is the simplest recurrent layer, but it is prone to the vanishing gradient problem.
Function Definition
torch.nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0.0, bidirectional=False)
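Note that `batch_first` defaults to `False`, so by default the input is laid out as (seq_len, batch, input_size), not batch-first. A minimal sketch of the default layout (dimension sizes here are arbitrary illustration values):

```python
import torch
import torch.nn as nn

# With the default batch_first=False, input is (seq_len, batch, input_size)
rnn = nn.RNN(input_size=8, hidden_size=16)   # num_layers defaults to 1
x = torch.randn(5, 3, 8)                     # (seq_len=5, batch=3, input_size=8)
output, hidden = rnn(x)
print(output.shape)  # torch.Size([5, 3, 16]) -> (seq_len, batch, hidden_size)
print(hidden.shape)  # torch.Size([1, 3, 16]) -> (num_layers, batch, hidden_size)
```

The examples below all pass `batch_first=True`, which swaps the first two input/output dimensions to (batch, seq_len, feature).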
Usage Examples
Example 1: Basic Usage
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=256, hidden_size=128, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 256)  # (batch, seq_len, input_size)
output, hidden = rnn(x)
print("Input:", x.shape)       # torch.Size([4, 10, 256])
print("Output:", output.shape) # torch.Size([4, 10, 128])
print("Hidden:", hidden.shape) # torch.Size([2, 4, 128])
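If no initial hidden state is passed, it defaults to zeros. A sketch of supplying one explicitly, reusing the same shapes as Example 1:

```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=256, hidden_size=128, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 256)
# Explicit initial hidden state: (num_layers, batch, hidden_size)
h0 = torch.zeros(2, 4, 128)
output, hidden = rnn(x, h0)
print(output.shape)  # torch.Size([4, 10, 128])
print(hidden.shape)  # torch.Size([2, 4, 128])
```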
Example 2: Multi-Layer RNN
import torch
import torch.nn as nn

# 3-layer RNN with dropout between layers
rnn = nn.RNN(128, 256, num_layers=3, dropout=0.3, batch_first=True)
x = torch.randn(2, 50, 128)
out, h = rnn(x)
print("Input:", x.shape)   # torch.Size([2, 50, 128])
print("Output:", out.shape) # torch.Size([2, 50, 256])
print("Hidden:", h.shape)   # (num_layers, batch, hidden_size) = (3, 2, 256)
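Setting `bidirectional=True` doubles both the output feature dimension and the number of hidden-state layers. A sketch built on the same configuration as Example 2:

```python
import torch
import torch.nn as nn

rnn = nn.RNN(128, 256, num_layers=3, batch_first=True, bidirectional=True)
x = torch.randn(2, 50, 128)
out, h = rnn(x)
print(out.shape)  # torch.Size([2, 50, 512]) -> hidden_size * 2 directions
print(h.shape)    # torch.Size([6, 2, 256]) -> num_layers * 2 directions
```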
Example 3: Nonlinear Activation
import torch
import torch.nn as nn

# tanh is the default activation function
rnn = nn.RNN(64, 64, batch_first=True)
x = torch.randn(1, 5, 64)
out, _ = rnn(x)
print("Output shape:", out.shape)  # torch.Size([1, 5, 64])
print("RNN uses tanh activation by default")
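The activation can be switched to ReLU via the `nonlinearity` parameter (the only two supported values are 'tanh' and 'relu'). A sketch reusing the shapes of Example 3:

```python
import torch
import torch.nn as nn

# Switch the activation from the default tanh to relu
rnn = nn.RNN(64, 64, batch_first=True, nonlinearity='relu')
x = torch.randn(1, 5, 64)
out, _ = rnn(x)
print(out.shape)         # torch.Size([1, 5, 64])
print((out >= 0).all())  # ReLU output is non-negative
```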
Notes
The basic RNN suffers from vanishing gradients; for long sequences, prefer LSTM or GRU.
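nn.LSTM is close to a drop-in replacement: it takes the same constructor arguments (minus `nonlinearity`), but additionally returns a cell state alongside the hidden state. A minimal swap of Example 1:

```python
import torch
import torch.nn as nn

# Same constructor arguments as the nn.RNN in Example 1
lstm = nn.LSTM(input_size=256, hidden_size=128, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 256)
output, (h_n, c_n) = lstm(x)  # LSTM also returns a cell state c_n
print(output.shape)  # torch.Size([4, 10, 128])
print(h_n.shape)     # torch.Size([2, 4, 128])
print(c_n.shape)     # torch.Size([2, 4, 128])
```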
Use Cases
- Simple sequence tasks: short sequences
- Teaching examples: understanding how RNNs work
- Rapid prototyping: simple baselines

PyTorch torch.nn Reference Manual