现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.LogSoftmax 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.LogSoftmax 是 PyTorch 中的 Log Softmax 激活函数。

它是 Softmax 的对数形式,数值更稳定,常与 NLLLoss 配合使用。

函数定义

torch.nn.LogSoftmax(dim=None)

公式

LogSoftmax(x_i) = log(exp(x_i) / sum(exp(x_j)))

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

log_softmax = nn.LogSoftmax(dim=1)

logits = torch.tensor([[2.0, 1.0, 0.1]])
log_probs = log_softmax(logits)

print("Logits:", logits.tolist())
print("Log Softmax:", log_probs.tolist())
print("exp后:", log_probs.exp().tolist())

示例 2: 与 NLLLoss 配合

实例

import torch
import torch.nn as nn

# 分类任务
logits = torch.randn(4, 10)
targets = torch.tensor([2, 5, 1, 7])

# LogSoftmax + NLLLoss = CrossEntropyLoss
loss = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
print("NLL Loss:", loss.item())

# 等价于
loss2 = nn.CrossEntropyLoss()(logits, targets)
print("CrossEntropyLoss:", loss2.item())

示例 3: 数值稳定性

实例

import torch
import torch.nn as nn

# 大值 logits
logits = torch.tensor([[1000, 1001, 1002]])

# Softmax 可能溢出
try:
    sm = nn.Softmax(dim=1)(logits)
    print("Softmax:", sm)
except:
    print("Softmax 溢出")

# LogSoftmax 数值稳定
lsm = nn.LogSoftmax(dim=1)(logits)
print("LogSoftmax:", lsm)

使用场景

  • 分类任务: 配合 NLLLoss
  • 数值稳定: 大 logit 值

提示:LogSoftmax + NLLLoss 等价于 CrossEntropyLoss。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册