现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.GELU 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.GELU 是 PyTorch 中的高斯误差线性单元(Gaussian Error Linear Unit)激活函数。

它是 Transformer 架构的默认激活函数,相比 ReLU 具有更好的性能和平滑的梯度。

函数定义

torch.nn.GELU(approximate='none')

参数说明:

  • approximate (str): 近似算法。可选 'none''tanh'。默认为 'none'

数学原理

GELU 的数学公式:

GELU(x) = x * Φ(x)

其中 Φ(x) 是标准正态分布的累积分布函数(CDF)。

使用 tanh 近似时:

GELU(x) ≈ 0.5x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))

使用示例

示例 1: 基本用法

创建并使用 GELU 激活:

实例

import torch
import torch.nn as nn

# 创建 GELU 激活层
gelu = nn.GELU()

# 测试输入
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])

# 前向传播
output = gelu(x)

print("输入:", x.tolist())
print("输出:", output.tolist())
print("n观察:负值有轻微激活(非零),正值保持增长")

示例 2: 对比不同激活函数

比较 GELU、ReLU、Sigmoid:

实例

import torch
import torch.nn as nn

x = torch.linspace(-4, 4, 21)

# 不同激活函数
gelu = nn.GELU()
relu = nn.ReLU()
sigmoid = nn.Sigmoid()
tanh = nn.Tanh()

print("x       GELU      ReLU      Sigmoid   Tanh")
print("-" * 50)
for i in range(0, 21, 3):
    xi = x[i:i+3]
    print(f"{xi[0]:6.2f} {gelu(xi)[0]:8.4f} {relu(xi)[0]:8.4f} {sigmoid(xi)[0]:8.4f} {tanh(xi)[0]:8.4f}")

示例 3: 在 Transformer 中使用

典型的 Transformer FFN 层:

h2 class="example">实例
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, d_model, dim_feedforward=2048, dropout=0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(dim_feedforward, d_model)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

# 测试 FFN
ffn = FeedForward(d_model=512, dim_feedforward=2048)
x = torch.randn(32, 100, 512)  # (batch, seq, d_model)

output = ffn(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)

示例 4: 使用 tanh 近似

使用 tanh 近似加速计算:

实例

import torch
import torch.nn as nn

# 精确版本
gelu_exact = nn.GELU(approximate='none')

# tanh 近似版本
gelu_approx = nn.GELU(approximate='tanh')

x = torch.randn(1000)

output_exact = gelu_exact(x)
output_approx = gelu_approx(x)

# 计算差异
diff = (output_exact - output_approx).abs().max().item()
print(f"最大差异: {diff:.8f}")

# 性能对比
import time

for _ in range(100):
    _ = gelu_exact(x)

start = time.time()
for _ in range(1000):
    _ = gelu_exact(x)
time_exact = time.time() - start

start = time.time()
for _ in range(1000):
    _ = gelu_approx(x)
time_approx = time.time() - start

print(f"精确版本时间: {time_exact:.4f}s")
print(f"近似版本时间: {time_approx:.4f}s")

激活函数对比

激活函数 特点 适用场景
nn.GELU 平滑、非零负值、Transformer 默认 Transformer、BERT、GPT
nn.ReLU 简单、稀疏激活、死神经元 CNN、通用深度学习
nn.SiLU 平滑、自我门控 MobileNet、EfficientNet

常见问题

Q1: GELU 相比 ReLU 有什么优势?

  • 负值有轻微激活,信息不丢失
  • 梯度更平滑,有助于训练
  • 在 Transformer 中表现更好

Q2: 何时使用近似版本?

当对推理速度有要求且精度要求不严格时,tanh 近似更快。

Q3: GELU 可以用于输出层吗?

通常不用于输出层。分类任务用 Softmax,回归任务用恒等函数。


使用场景

nn.GELU 主要应用场景包括:

  • Transformer 架构: BERT、GPT 等模型
  • 深度神经网络: 需要平滑激活的场合
  • 预训练模型: 现代 NLP 模型

提示:GELU 是当前 NLP 领域最常用的激活函数,是 Transformer 的标配。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册