现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.ELU 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.ELU 是 PyTorch 中的指数线性单元激活函数。

它对负值使用指数函数,允许负值有非零输出,梯度更平滑。

函数定义

torch.nn.ELU(alpha=1.0, inplace=False)

公式

ELU(x) = x, x > 0
ELU(x) = alpha * (e^x - 1), x <= 0

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

elu = nn.ELU(alpha=1.0)

x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
output = elu(x)

print("输入:", x.tolist())
print("输出:", output.tolist())

示例 2: 对比 ReLU

实例

import torch
import torch.nn as nn
import numpy as np

x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0])
x_t = torch.tensor(x)

print("x        ELU       ReLU")
for i in range(0, 6, 1):
    print(f"{x[i]:6.1f} {nn.ELU()(x_t[i:i+1]).item():9.4f} {nn.ReLU()(x_t[i:i+1]).item():9.4f}")

示例 3: alpha 参数

实例

import torch
import torch.nn as nn

# 不同 alpha 值
x = torch.tensor([-1.0])

for alpha in [0.5, 1.0, 1.5, 2.0]:
    out = nn.ELU(alpha=alpha)(x)
    print(f"alpha={alpha}: {out.item():.4f}")

使用场景

  • 自编码器
  • 需要负值输出
  • 平滑梯度

提示:ELU 计算比 ReLU 慢,但收敛更快。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册