现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.GRU 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.GRU 是 PyTorch 中的门控循环单元模块。

GRU 是 LSTM 的简化版本,参数更少,计算更快,效果相近。

函数定义

torch.nn.GRU(input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0, bidirectional=False)

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

# GRU: 输入256维,隐藏256维,2层
gru = nn.GRU(input_size=256, hidden_size=256, num_layers=2, batch_first=True)

# 输入:batch=4,序列=10,特征=256
x = torch.randn(4, 10, 256)

output, hidden = gru(x)

print("输入形状:", x.shape)
print("输出形状:", output.shape)
print("隐藏状态形状:", hidden.shape)

示例 2: 对比 LSTM

实例

import torch
import torch.nn as nn
import time

# 相同配置的 LSTM 和 GRU
lstm = nn.LSTM(256, 256, 1, batch_first=True)
gru = nn.GRU(256, 256, 1, batch_first=True)

x = torch.randn(32, 100, 256)

# 性能对比
for model, name in [(lstm, "LSTM"), (gru, "GRU")]:
    start = time.time()
    for _ in range(100):
        _ = model(x)
    print(f"{name} 时间: {time.time()-start:.3f}s")

示例 3: 分类任务

实例

import torch
import torch.nn as nn

class GRUClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super(GRUClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)
        _, hidden = self.gru(embedded)
        # 拼接双向最后隐藏状态
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        return self.fc(hidden)

model = GRUClassifier(10000, 128, 128, 2)
x = torch.randint(0, 10000, (8, 50))
output = model(x)

print("输入:", x.shape, "-> 输出:", output.shape)

LSTM vs GRU

方面 LSTM GRU
参数量 较多 较少
门控 3个门 2个门
计算 较慢 较快

使用场景

  • 序列建模: 文本、音频
  • 快速原型: 资源有限时
  • 机器翻译: encoder 端

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册