PyTorch torch.nn.ELU Function
torch.nn.ELU is the Exponential Linear Unit activation function in PyTorch.
It applies an exponential curve to negative inputs, so negative values produce non-zero outputs and the gradient is smoother than ReLU's.
Function Definition
torch.nn.ELU(alpha=1.0, inplace=False)
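The alpha parameter sets the saturation value for negative inputs, and inplace=True overwrites the input tensor instead of allocating a new one. A minimal sketch of the inplace behavior (input values are illustrative; expected output shown as a comment):
import torch
import torch.nn as nn

x = torch.tensor([-1.0, 0.5])
nn.ELU(inplace=True)(x)   # modifies x directly, no new tensor allocated
print(x)                  # tensor([-0.6321,  0.5000])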
Formula
ELU(x) = x,                    x > 0
ELU(x) = alpha * (e^x - 1),    x <= 0
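A quick numerical check of the formula against the module (a sketch; math.exp computes the hand-calculated side for comparison):
import math
import torch
import torch.nn as nn

x = torch.tensor([-2.0, 3.0])
print(nn.ELU(alpha=1.0)(x))        # tensor([-0.8647,  3.0000])
print(1.0 * (math.exp(-2.0) - 1))  # -0.8646647... matches the negative branch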
Usage Examples
Example 1: Basic Usage
import torch
import torch.nn as nn
elu = nn.ELU(alpha=1.0)
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
output = elu(x)
print("输入:", x.tolist())
print("输出:", output.tolist())
Example 2: Comparison with ReLU
import torch
import torch.nn as nn
import numpy as np

x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0])
x_t = torch.tensor(x)
print(f"{'x':>6} {'ELU':>9} {'ReLU':>9}")
for i in range(len(x)):
    print(f"{x[i]:6.1f} {nn.ELU()(x_t[i:i+1]).item():9.4f} {nn.ReLU()(x_t[i:i+1]).item():9.4f}")
Example 3: The alpha Parameter
import torch
import torch.nn as nn

# Compare different alpha values at the same negative input
x = torch.tensor([-1.0])
for alpha in [0.5, 1.0, 1.5, 2.0]:
    out = nn.ELU(alpha=alpha)(x)
    print(f"alpha={alpha}: {out.item():.4f}")
Use Cases
- Autoencoders (see the sketch after this list)
- Models that need negative activation outputs
- Smoother gradients than ReLU
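As an illustration of the first use case, a minimal encoder half of an autoencoder using nn.ELU inside nn.Sequential (the 784 -> 32 layer sizes are arbitrary assumptions for illustration):
import torch
import torch.nn as nn

encoder = nn.Sequential(
    nn.Linear(784, 128),
    nn.ELU(),
    nn.Linear(128, 32),
    nn.ELU(),
)
print(encoder(torch.randn(1, 784)).shape)   # torch.Size([1, 32])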
Tip: ELU is more expensive to compute than ReLU, but it often converges faster.
