PyTorch torch.nn.LogSoftmax 函数
torch.nn.LogSoftmax 是 PyTorch 中的 Log Softmax 激活函数。
它是 Softmax 的对数形式,数值更稳定,常与 NLLLoss 配合使用。
函数定义
torch.nn.LogSoftmax(dim=None)
公式
LogSoftmax(x_i) = log(exp(x_i) / sum(exp(x_j)))
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn
log_softmax = nn.LogSoftmax(dim=1)
logits = torch.tensor([[2.0, 1.0, 0.1]])
log_probs = log_softmax(logits)
print("Logits:", logits.tolist())
print("Log Softmax:", log_probs.tolist())
print("exp后:", log_probs.exp().tolist())
import torch.nn as nn
log_softmax = nn.LogSoftmax(dim=1)
logits = torch.tensor([[2.0, 1.0, 0.1]])
log_probs = log_softmax(logits)
print("Logits:", logits.tolist())
print("Log Softmax:", log_probs.tolist())
print("exp后:", log_probs.exp().tolist())
示例 2: 与 NLLLoss 配合
实例
import torch
import torch.nn as nn
# 分类任务
logits = torch.randn(4, 10)
targets = torch.tensor([2, 5, 1, 7])
# LogSoftmax + NLLLoss = CrossEntropyLoss
loss = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
print("NLL Loss:", loss.item())
# 等价于
loss2 = nn.CrossEntropyLoss()(logits, targets)
print("CrossEntropyLoss:", loss2.item())
import torch.nn as nn
# 分类任务
logits = torch.randn(4, 10)
targets = torch.tensor([2, 5, 1, 7])
# LogSoftmax + NLLLoss = CrossEntropyLoss
loss = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
print("NLL Loss:", loss.item())
# 等价于
loss2 = nn.CrossEntropyLoss()(logits, targets)
print("CrossEntropyLoss:", loss2.item())
示例 3: 数值稳定性
实例
import torch
import torch.nn as nn
# 大值 logits
logits = torch.tensor([[1000, 1001, 1002]])
# Softmax 可能溢出
try:
sm = nn.Softmax(dim=1)(logits)
print("Softmax:", sm)
except:
print("Softmax 溢出")
# LogSoftmax 数值稳定
lsm = nn.LogSoftmax(dim=1)(logits)
print("LogSoftmax:", lsm)
import torch.nn as nn
# 大值 logits
logits = torch.tensor([[1000, 1001, 1002]])
# Softmax 可能溢出
try:
sm = nn.Softmax(dim=1)(logits)
print("Softmax:", sm)
except:
print("Softmax 溢出")
# LogSoftmax 数值稳定
lsm = nn.LogSoftmax(dim=1)(logits)
print("LogSoftmax:", lsm)
使用场景
- 分类任务: 配合 NLLLoss
- 数值稳定: 大 logit 值
提示:LogSoftmax + NLLLoss 等价于 CrossEntropyLoss。

PyTorch torch.nn 参考手册