PyTorch torch.nn.GRU 函数
torch.nn.GRU 是 PyTorch 中的门控循环单元模块。
GRU 是 LSTM 的简化版本,参数更少,计算更快,效果相近。
函数定义
torch.nn.GRU(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False)
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn

# A 2-layer GRU taking 256-dim inputs and producing a 256-dim hidden state.
# batch_first=True lays tensors out as (batch, seq, feature).
gru = nn.GRU(input_size=256, hidden_size=256, num_layers=2, batch_first=True)

# One batch: 4 sequences, 10 time steps each, 256 features per step.
x = torch.randn(4, 10, 256)
output, hidden = gru(x)

# output is (batch, seq, hidden); hidden is (num_layers, batch, hidden).
for label, tensor in (("输入形状:", x), ("输出形状:", output), ("隐藏状态形状:", hidden)):
    print(label, tensor.shape)
# NOTE: the original snippet omitted `import torch` even though it calls
# torch.randn below; add it so the example runs standalone.
import torch
import torch.nn as nn

# GRU: 256-dim input, 256-dim hidden state, 2 stacked layers,
# batch-first (batch, seq, feature) tensor layout.
gru = nn.GRU(input_size=256, hidden_size=256, num_layers=2, batch_first=True)

# Input batch: batch=4, sequence length=10, feature dim=256.
x = torch.randn(4, 10, 256)
output, hidden = gru(x)

print("输入形状:", x.shape)        # torch.Size([4, 10, 256])
print("输出形状:", output.shape)   # (batch, seq, hidden) -> [4, 10, 256]
print("隐藏状态形状:", hidden.shape)  # (num_layers, batch, hidden) -> [2, 4, 256]
示例 2: 对比 LSTM
实例
import torch
import torch.nn as nn
import time
# 相同配置的 LSTM 和 GRU
lstm = nn.LSTM(256, 256, 1, batch_first=True)
gru = nn.GRU(256, 256, 1, batch_first=True)
x = torch.randn(32, 100, 256)
# 性能对比
for model, name in [(lstm, "LSTM"), (gru, "GRU")]:
start = time.time()
for _ in range(100):
_ = model(x)
print(f"{name} 时间: {time.time()-start:.3f}s")
import torch.nn as nn
import time
# 相同配置的 LSTM 和 GRU
lstm = nn.LSTM(256, 256, 1, batch_first=True)
gru = nn.GRU(256, 256, 1, batch_first=True)
x = torch.randn(32, 100, 256)
# 性能对比
for model, name in [(lstm, "LSTM"), (gru, "GRU")]:
start = time.time()
for _ in range(100):
_ = model(x)
print(f"{name} 时间: {time.time()-start:.3f}s")
示例 3: 分类任务
实例
import torch
import torch.nn as nn
class GRUClassifier(nn.Module):
    """Sequence classifier: embedding -> bidirectional GRU -> linear head."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        # Bidirectional: forward and backward final states get concatenated,
        # so the classification head sees 2 * hidden_dim features.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)
        _, hidden = self.gru(embedded)
        # hidden is (num_layers * 2, batch, hidden_dim); its last two entries
        # are the top layer's forward / backward final states — join them.
        final = torch.cat((hidden[-2], hidden[-1]), dim=1)
        return self.fc(final)
# Instantiate: 10k-word vocab, 128-dim embeddings and hidden state, 2 classes.
model = GRUClassifier(10000, 128, 128, 2)

# A batch of 8 token-id sequences, 50 tokens each.
x = torch.randint(0, 10000, (8, 50))
output = model(x)

print("输入:", x.shape, "-> 输出:", output.shape)
# NOTE: the original snippet omitted `import torch` even though torch.cat and
# torch.randint are used below; add it so the example runs standalone.
import torch
import torch.nn as nn

class GRUClassifier(nn.Module):
    """Sequence classifier: embedding -> bidirectional GRU -> linear head."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        # 2 * hidden_dim: forward + backward final states are concatenated.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)
        _, hidden = self.gru(embedded)
        # Concatenate the top layer's forward (hidden[-2]) and
        # backward (hidden[-1]) final hidden states.
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        return self.fc(hidden)

model = GRUClassifier(10000, 128, 128, 2)
x = torch.randint(0, 10000, (8, 50))
output = model(x)
print("输入:", x.shape, "-> 输出:", output.shape)
LSTM vs GRU
| 方面 | LSTM | GRU |
|---|---|---|
| 参数量 | 较多 | 较少 |
| 门控 | 3个门 | 2个门 |
| 计算 | 较慢 | 较快 |
使用场景
- 序列建模: 文本、音频
- 快速原型: 资源有限时
- 机器翻译: encoder 端

PyTorch torch.nn 参考手册