现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.NLLLoss 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.NLLLoss 是 PyTorch 中的负对数似然损失。

它用于多分类任务,需要配合 LogSoftmax 使用。

函数定义

torch.nn.NLLLoss(weight=None, ignore_index=-100, reduction='mean')

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

# 先 LogSoftmax 再 NLLLoss
log_softmax = nn.LogSoftmax(dim=1)
nll = nn.NLLLoss()

# 模型输出 (logits)
logits = torch.randn(4, 10)
# 经过 log softmax
log_probs = log_softmax(logits)
# 真实标签
targets = torch.tensor([2, 5, 1, 7])

loss = nll(log_probs, targets)
print("NLL Loss:", loss.item())

示例 2: 等价于 CrossEntropyLoss

实例

import torch
import torch.nn as nn

logits = torch.randn(4, 10)
targets = torch.tensor([2, 5, 1, 7])

# 方式1:直接用 CrossEntropyLoss
loss1 = nn.CrossEntropyLoss()(logits, targets)

# 方式2:LogSoftmax + NLLLoss
loss2 = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)

print("CrossEntropyLoss:", loss1.item())
print("LogSoftmax + NLLLoss:", loss2.item())
print("结果相同:", abs(loss1 - loss2) < 1e-5)

使用场景

  • 多分类: 配合 LogSoftmax
  • 自定义损失: 特殊需求

注意:CrossEntropyLoss 已包含 LogSoftmax,通常直接使用。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册