PyTorch torch.nn.GELU 函数
torch.nn.GELU 是 PyTorch 中的高斯误差线性单元(Gaussian Error Linear Unit)激活函数。
它是 Transformer 架构的默认激活函数,相比 ReLU 具有更好的性能和平滑的梯度。
函数定义
torch.nn.GELU(approximate='none')
参数说明:
approximate(str): 近似算法。可选'none'、'tanh'。默认为'none'。
数学原理
GELU 的数学公式:
GELU(x) = x * Φ(x)
其中 Φ(x) 是标准正态分布的累积分布函数(CDF)。
使用 tanh 近似时:
GELU(x) ≈ 0.5x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
使用示例
示例 1: 基本用法
创建并使用 GELU 激活:
实例
import torch
import torch.nn as nn
# 创建 GELU 激活层
gelu = nn.GELU()
# 测试输入
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
# 前向传播
output = gelu(x)
print("输入:", x.tolist())
print("输出:", output.tolist())
print("n观察:负值有轻微激活(非零),正值保持增长")
import torch.nn as nn
# 创建 GELU 激活层
gelu = nn.GELU()
# 测试输入
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
# 前向传播
output = gelu(x)
print("输入:", x.tolist())
print("输出:", output.tolist())
print("n观察:负值有轻微激活(非零),正值保持增长")
示例 2: 对比不同激活函数
比较 GELU、ReLU、Sigmoid:
实例
import torch
import torch.nn as nn
x = torch.linspace(-4, 4, 21)
# 不同激活函数
gelu = nn.GELU()
relu = nn.ReLU()
sigmoid = nn.Sigmoid()
tanh = nn.Tanh()
print("x GELU ReLU Sigmoid Tanh")
print("-" * 50)
for i in range(0, 21, 3):
xi = x[i:i+3]
print(f"{xi[0]:6.2f} {gelu(xi)[0]:8.4f} {relu(xi)[0]:8.4f} {sigmoid(xi)[0]:8.4f} {tanh(xi)[0]:8.4f}")
import torch.nn as nn
x = torch.linspace(-4, 4, 21)
# 不同激活函数
gelu = nn.GELU()
relu = nn.ReLU()
sigmoid = nn.Sigmoid()
tanh = nn.Tanh()
print("x GELU ReLU Sigmoid Tanh")
print("-" * 50)
for i in range(0, 21, 3):
xi = x[i:i+3]
print(f"{xi[0]:6.2f} {gelu(xi)[0]:8.4f} {relu(xi)[0]:8.4f} {sigmoid(xi)[0]:8.4f} {tanh(xi)[0]:8.4f}")
示例 3: 在 Transformer 中使用
典型的 Transformer FFN 层:
h2 class="example">实例
import torch
import torch.nn as nn
class FeedForward(nn.Module):
def __init__(self, d_model, dim_feedforward=2048, dropout=0.1):
super(FeedForward, self).__init__()
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.activation = nn.GELU()
self.linear2 = nn.Linear(dim_feedforward, d_model)
def forward(self, x):
x = self.linear1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.linear2(x)
return x
# 测试 FFN
ffn = FeedForward(d_model=512, dim_feedforward=2048)
x = torch.randn(32, 100, 512) # (batch, seq, d_model)
output = ffn(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
import torch.nn as nn
class FeedForward(nn.Module):
def __init__(self, d_model, dim_feedforward=2048, dropout=0.1):
super(FeedForward, self).__init__()
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.activation = nn.GELU()
self.linear2 = nn.Linear(dim_feedforward, d_model)
def forward(self, x):
x = self.linear1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.linear2(x)
return x
# 测试 FFN
ffn = FeedForward(d_model=512, dim_feedforward=2048)
x = torch.randn(32, 100, 512) # (batch, seq, d_model)
output = ffn(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
示例 4: 使用 tanh 近似
使用 tanh 近似加速计算:
实例
import torch
import torch.nn as nn
# 精确版本
gelu_exact = nn.GELU(approximate='none')
# tanh 近似版本
gelu_approx = nn.GELU(approximate='tanh')
x = torch.randn(1000)
output_exact = gelu_exact(x)
output_approx = gelu_approx(x)
# 计算差异
diff = (output_exact - output_approx).abs().max().item()
print(f"最大差异: {diff:.8f}")
# 性能对比
import time
for _ in range(100):
_ = gelu_exact(x)
start = time.time()
for _ in range(1000):
_ = gelu_exact(x)
time_exact = time.time() - start
start = time.time()
for _ in range(1000):
_ = gelu_approx(x)
time_approx = time.time() - start
print(f"精确版本时间: {time_exact:.4f}s")
print(f"近似版本时间: {time_approx:.4f}s")
import torch.nn as nn
# 精确版本
gelu_exact = nn.GELU(approximate='none')
# tanh 近似版本
gelu_approx = nn.GELU(approximate='tanh')
x = torch.randn(1000)
output_exact = gelu_exact(x)
output_approx = gelu_approx(x)
# 计算差异
diff = (output_exact - output_approx).abs().max().item()
print(f"最大差异: {diff:.8f}")
# 性能对比
import time
for _ in range(100):
_ = gelu_exact(x)
start = time.time()
for _ in range(1000):
_ = gelu_exact(x)
time_exact = time.time() - start
start = time.time()
for _ in range(1000):
_ = gelu_approx(x)
time_approx = time.time() - start
print(f"精确版本时间: {time_exact:.4f}s")
print(f"近似版本时间: {time_approx:.4f}s")
激活函数对比
| 激活函数 | 特点 | 适用场景 |
|---|---|---|
nn.GELU |
平滑、非零负值、Transformer 默认 | Transformer、BERT、GPT |
nn.ReLU |
简单、稀疏激活、死神经元 | CNN、通用深度学习 |
nn.SiLU |
平滑、自我门控 | MobileNet、EfficientNet |
常见问题
Q1: GELU 相比 ReLU 有什么优势?
- 负值有轻微激活,信息不丢失
- 梯度更平滑,有助于训练
- 在 Transformer 中表现更好
Q2: 何时使用近似版本?
当对推理速度有要求且精度要求不严格时,tanh 近似更快。
Q3: GELU 可以用于输出层吗?
通常不用于输出层。分类任务用 Softmax,回归任务用恒等函数。
使用场景
nn.GELU 主要应用场景包括:
- Transformer 架构: BERT、GPT 等模型
- 深度神经网络: 需要平滑激活的场合
- 预训练模型: 现代 NLP 模型
提示:GELU 是当前 NLP 领域最常用的激活函数,是 Transformer 的标配。

PyTorch torch.nn 参考手册