PyTorch torch.nn.NLLLoss Function
torch.nn.NLLLoss is the negative log-likelihood loss in PyTorch.
It is used for multi-class classification. Its input must be log-probabilities, so it is typically paired with LogSoftmax.
Function Definition
torch.nn.NLLLoss(weight=None, ignore_index=-100, reduction='mean')
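The `weight` and `ignore_index` parameters can be sketched as follows; the class weights and the `-100` padding label here are made up for illustration (`-100` happens to be the default `ignore_index`):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
log_probs = nn.LogSoftmax(dim=1)(torch.randn(4, 3))
# The target -100 is skipped entirely (matching the default ignore_index)
targets = torch.tensor([0, 2, -100, 1])

# weight rescales each class's contribution; ignore_index skips padding labels
class_weights = torch.tensor([1.0, 2.0, 0.5])
loss_fn = nn.NLLLoss(weight=class_weights, ignore_index=-100)
loss = loss_fn(log_probs, targets)
print(loss.item())
```

With `reduction='mean'` and a `weight` tensor, the result is a weighted average: the sum of weighted per-sample losses divided by the sum of the weights of the non-ignored targets.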
Usage Examples
Example 1: Basic Usage
Example
import torch
import torch.nn as nn
# Apply LogSoftmax first, then NLLLoss
log_softmax = nn.LogSoftmax(dim=1)
nll = nn.NLLLoss()
# Model output (logits)
logits = torch.randn(4, 10)
# Convert to log-probabilities
log_probs = log_softmax(logits)
# Ground-truth labels
targets = torch.tensor([2, 5, 1, 7])
loss = nll(log_probs, targets)
print("NLL Loss:", loss.item())
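What NLLLoss computes in the example above can be reproduced by hand: it picks the log-probability at each target index, negates it, and averages over the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
log_probs = nn.LogSoftmax(dim=1)(torch.randn(4, 10))
targets = torch.tensor([2, 5, 1, 7])

# Select each row's log-probability at its target class, negate, average
manual = -log_probs[torch.arange(4), targets].mean()
builtin = nn.NLLLoss()(log_probs, targets)
print(torch.allclose(manual, builtin))  # True
```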
Example 2: Equivalence to CrossEntropyLoss
Example
import torch
import torch.nn as nn
logits = torch.randn(4, 10)
targets = torch.tensor([2, 5, 1, 7])
# Method 1: CrossEntropyLoss directly
loss1 = nn.CrossEntropyLoss()(logits, targets)
# Method 2: LogSoftmax + NLLLoss
loss2 = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
print("CrossEntropyLoss:", loss1.item())
print("LogSoftmax + NLLLoss:", loss2.item())
print("Same result:", abs(loss1.item() - loss2.item()) < 1e-5)
Use Cases
- Multi-class classification: paired with LogSoftmax
- Custom losses: when you need direct control over the log-probabilities
Note: CrossEntropyLoss already includes LogSoftmax, so in practice it is usually used directly.
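NLLLoss also accepts higher-dimensional input for per-location classification, such as image segmentation: input of shape (N, C, H, W) with a target of shape (N, H, W). A minimal sketch, with illustrative shapes:

```python
import torch
import torch.nn as nn

# 2 images, 5 classes, 4x4 spatial grid
log_probs = nn.LogSoftmax(dim=1)(torch.randn(2, 5, 4, 4))
# One class index per pixel
targets = torch.randint(0, 5, (2, 4, 4))

# The default reduction='mean' averages over all pixels, yielding a scalar
loss = nn.NLLLoss()(log_probs, targets)
print(loss.shape)
```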

PyTorch torch.nn Reference Manual