现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.CrossEntropyLoss 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.CrossEntropyLoss 是 PyTorch 中用于多分类的损失函数。

它结合了 nn.LogSoftmax 和 nn.NLLLoss,常用于图像分类、文本分类等任务。

函数定义

torch.nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0)

参数说明:

  • weight (Tensor): 为每个类赋予不同的权重,用于类别不平衡的情况。
  • ignore_index (int): 忽略指定索引的损失计算。默认为 -100。
  • reduction (str): 损失聚合方式。可选 'mean''sum''none'。默认为 'mean'
  • label_smoothing (float): 标签平滑参数,取值 0 到 1 之间。默认为 0。

数学原理

交叉熵损失的公式:

Loss = -log(exp(y_true) / sum(exp(y_i)))

即正确类别的预测概率越大,损失越小。


使用示例

示例 1: 基本用法

创建并使用交叉熵损失:

实例

import torch
import torch.nn as nn

# 创建损失函数
criterion = nn.CrossEntropyLoss()

# 模型输出的 logits(未归一化)
# 形状: (batch_size, num_classes)
outputs = torch.randn(4, 10)

# 真实标签
labels = torch.tensor([2, 5, 1, 7])

# 计算损失
loss = criterion(outputs, labels)

print("模型输出 (logits):", outputs[0].tolist())
print("真实标签:", labels[0].item())
print("交叉熵损失:", loss.item())

示例 2: 类别权重

处理类别不平衡:

实例

import torch
import torch.nn as nn

# 类别权重:给少数类更高的权重
# 假设 10 个类别,第 3 类和第 7 类更重要
weight = torch.ones(10)
weight[3] = 2.0
weight[7] = 2.0

criterion_weighted = nn.CrossEntropyLoss(weight=weight)

outputs = torch.randn(4, 10)
labels = torch.tensor([2, 3, 7, 5])

loss = criterion_weighted(outputs, labels)
print("加权交叉熵损失:", loss.item())

示例 3: 标签平滑

使用标签平滑防止过拟合:

实例

import torch
import torch.nn as nn

# 标签平滑:0.1 表示将 10% 的概率均匀分配给其他类别
criterion_smooth = nn.CrossEntropyLoss(label_smoothing=0.1)

outputs = torch.randn(4, 10)
labels = torch.tensor([2, 5, 1, 7])

loss = criterion_smooth(outputs, labels)
print("带标签平滑的损失:", loss.item())

# 对比:不带标签平滑
criterion = nn.CrossEntropyLoss()
loss_no_smooth = criterion(outputs, labels)
print("不带标签平滑的损失:", loss_no_smooth.item())

示例 4: 完整的分类训练流程

完整的模型训练示例:

实例

import torch
import torch.nn as nn
import torch.optim as optim

# 简单的分类模型
class Classifier(nn.Module):
    def __init__(self, input_dim=784, num_classes=10):
        super(Classifier, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        return self.fc(x)

# 初始化模型和损失
model = Classifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 模拟训练数据
batch_size = 32
x = torch.randn(batch_size, 784)  # 输入
y = torch.randint(0, 10, (batch_size,))  # 标签

# 前向传播
model.train()
outputs = model(x)
loss = criterion(outputs, y)

# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("批次损失:", loss.item())

# 预测
model.eval()
with torch.no_grad():
    outputs = model(x)
    predictions = outputs.argmax(dim=1)
    accuracy = (predictions == y).float().mean()

print("预测准确率:", accuracy.item())

示例 5: ignore_index 的使用

忽略特定标签:

实例

import torch
import torch.nn as nn

# 忽略 label=-100 的样本
criterion = nn.CrossEntropyLoss(ignore_index=-100)

outputs = torch.randn(5, 10)
# 某些样本的标签为 -100,表示忽略
labels = torch.tensor([2, -100, 5, -100, 7])

loss = criterion(outputs, labels)
print("忽略特殊标签后的损失:", loss.item())

示例 6: 不同的 reduction 方式

控制损失聚合方式:

实例

import torch
import torch.nn as nn

outputs = torch.randn(4, 10)
labels = torch.tensor([2, 5, 1, 7])

# mean: 返回平均损失
loss_mean = nn.CrossEntropyLoss(reduction='mean')(outputs, labels)
print("mean:", loss_mean.item())

# sum: 返回总和
loss_sum = nn.CrossEntropyLoss(reduction='sum')(outputs, labels)
print("sum:", loss_sum.item())

# none: 返回每个样本的损失
loss_none = nn.CrossEntropyLoss(reduction='none')(outputs, labels)
print("none:", loss_none.tolist())

常见问题

Q1: CrossEntropyLoss 和 NLLLoss 有什么区别?

CrossEntropyLoss = LogSoftmax + NLLLoss。它已经内置了 softmax,不需要手动添加。

Q2: 为什么模型的输出不用 softmax?

CrossEntropyLoss 内部会自动计算 softmax,直接使用 logits 可以提高数值稳定性。

Q3: 标签平滑适合什么场景?

标签平滑适合分类数量较多的情况,可以提高模型的泛化能力。


使用场景

nn.CrossEntropyLoss 主要应用场景包括:

  • 图像分类: 如 CIFAR-10、ImageNet
  • 文本分类: 情感分析、主题分类
  • 多分类任务: 任何类别数大于 2 的分类任务

注意:标签应该是类别索引(0 到 num_classes-1),不是 one-hot 编码。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册