PyTorch torch.nn.CrossEntropyLoss Function
torch.nn.CrossEntropyLoss is PyTorch's loss function for multi-class classification.
It combines nn.LogSoftmax and nn.NLLLoss, and is commonly used for tasks such as image classification and text classification.
Function Definition
torch.nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0)
Parameters:
- weight (Tensor): assigns a per-class weight; useful for class-imbalanced data.
- ignore_index (int): target index whose samples are excluded from the loss computation. Default: -100.
- reduction (str): how to aggregate the loss; one of 'mean', 'sum', 'none'. Default: 'mean'.
- label_smoothing (float): label-smoothing factor, between 0 and 1. Default: 0.0.
Mathematical Background
The cross-entropy loss for a single sample with logits z and true class index y:
Loss = -log(exp(z_y) / Σ_j exp(z_j)) = -z_y + log(Σ_j exp(z_j))
That is, the higher the predicted probability of the correct class, the smaller the loss.
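To make the formula concrete, the following minimal sketch computes the loss by hand for a single sample and compares it with nn.CrossEntropyLoss (the tensor values are illustrative):
import torch
import torch.nn as nn
logits = torch.tensor([[1.0, 2.0, 0.5]])  # one sample, 3 classes
target = torch.tensor([1])                # true class index y = 1
# Manual computation: -z_y + log(sum_j exp(z_j))
manual = -logits[0, 1] + torch.logsumexp(logits[0], dim=0)
builtin = nn.CrossEntropyLoss()(logits, target)
print(manual.item(), builtin.item())  # the two values match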
Usage Examples
Example 1: Basic Usage
Create and apply the cross-entropy loss:
import torch
import torch.nn as nn
# Create the loss function
criterion = nn.CrossEntropyLoss()
# Model outputs: raw logits (unnormalized)
# Shape: (batch_size, num_classes)
outputs = torch.randn(4, 10)
# Ground-truth labels
labels = torch.tensor([2, 5, 1, 7])
# Compute the loss
loss = criterion(outputs, labels)
print("Model output (logits):", outputs[0].tolist())
print("True label:", labels[0].item())
print("Cross-entropy loss:", loss.item())
Example 2: Class Weights
Handle class imbalance:
import torch
import torch.nn as nn
# Class weights: give minority classes a higher weight
# Suppose there are 10 classes and classes 3 and 7 matter more
weight = torch.ones(10)
weight[3] = 2.0
weight[7] = 2.0
criterion_weighted = nn.CrossEntropyLoss(weight=weight)
outputs = torch.randn(4, 10)
labels = torch.tensor([2, 3, 7, 5])
loss = criterion_weighted(outputs, labels)
print("Weighted cross-entropy loss:", loss.item())
Example 3: Label Smoothing
Use label smoothing to reduce overfitting:
import torch
import torch.nn as nn
# Label smoothing: 0.1 means 10% of the probability mass is spread uniformly over all classes
criterion_smooth = nn.CrossEntropyLoss(label_smoothing=0.1)
outputs = torch.randn(4, 10)
labels = torch.tensor([2, 5, 1, 7])
loss = criterion_smooth(outputs, labels)
print("Loss with label smoothing:", loss.item())
# For comparison: without label smoothing
criterion = nn.CrossEntropyLoss()
loss_no_smooth = criterion(outputs, labels)
print("Loss without label smoothing:", loss_no_smooth.item())
Example 4: A Complete Classification Training Workflow
A complete model-training example:
import torch
import torch.nn as nn
import torch.optim as optim
# A simple classification model
class Classifier(nn.Module):
    def __init__(self, input_dim=784, num_classes=10):
        super(Classifier, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )
    def forward(self, x):
        return self.fc(x)
# Initialize the model, loss, and optimizer
model = Classifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Simulated training data
batch_size = 32
x = torch.randn(batch_size, 784)         # inputs
y = torch.randint(0, 10, (batch_size,))  # labels
# Forward pass
model.train()
outputs = model(x)
loss = criterion(outputs, y)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Batch loss:", loss.item())
# Prediction
model.eval()
with torch.no_grad():
    outputs = model(x)
    predictions = outputs.argmax(dim=1)
    accuracy = (predictions == y).float().mean()
print("Accuracy:", accuracy.item())
Example 5: Using ignore_index
Ignore specific labels:
import torch
import torch.nn as nn
# Ignore samples whose label is -100
criterion = nn.CrossEntropyLoss(ignore_index=-100)
outputs = torch.randn(5, 10)
# Some samples have label -100, meaning they are skipped
labels = torch.tensor([2, -100, 5, -100, 7])
loss = criterion(outputs, labels)
print("Loss after ignoring the special label:", loss.item())
Example 6: Different reduction Modes
Control how the loss is aggregated:
import torch
import torch.nn as nn
outputs = torch.randn(4, 10)
labels = torch.tensor([2, 5, 1, 7])
# mean: return the average loss
loss_mean = nn.CrossEntropyLoss(reduction='mean')(outputs, labels)
print("mean:", loss_mean.item())
# sum: return the total loss
loss_sum = nn.CrossEntropyLoss(reduction='sum')(outputs, labels)
print("sum:", loss_sum.item())
# none: return the per-sample losses
loss_none = nn.CrossEntropyLoss(reduction='none')(outputs, labels)
print("none:", loss_none.tolist())
FAQ
Q1: What is the difference between CrossEntropyLoss and NLLLoss?
CrossEntropyLoss = LogSoftmax + NLLLoss. The softmax step is already built in, so there is no need to add it manually.
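A quick sketch of this equivalence (random tensors for illustration):
import torch
import torch.nn as nn
logits = torch.randn(4, 10)
labels = torch.tensor([2, 5, 1, 7])
loss_ce = nn.CrossEntropyLoss()(logits, labels)
loss_nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)
print(loss_ce.item(), loss_nll.item())  # identical values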
Q2: Why is the model output not passed through softmax?
CrossEntropyLoss applies (log-)softmax internally; using raw logits directly improves numerical stability.
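The sketch below shows the failure mode with extreme logits: computing log(softmax(x)) by hand underflows to -inf, while the built-in loss stays finite:
import torch
import torch.nn.functional as F
big_logits = torch.tensor([[1000.0, 0.0, -1000.0]])
target = torch.tensor([2])
# Stable: the loss is computed via log-sum-exp on the raw logits
print(F.cross_entropy(big_logits, target).item())  # 2000.0
# Unstable: softmax underflows to 0, so -log(0) = inf
manual = -torch.log(F.softmax(big_logits, dim=1))[0, 2]
print(manual.item())  # inf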
Q3: When is label smoothing appropriate?
Label smoothing suits problems with many classes and can improve the model's generalization.
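Under the hood, label_smoothing=ε with K classes replaces the one-hot target with (1-ε) on the true class plus ε/K spread uniformly over all K classes. The following sketch rebuilds the smoothed target by hand and matches the built-in result:
import torch
import torch.nn as nn
import torch.nn.functional as F
eps, K = 0.1, 10
logits = torch.randn(4, K)
labels = torch.tensor([2, 5, 1, 7])
# Build the smoothed target distribution manually
soft = torch.full((4, K), eps / K)
soft[torch.arange(4), labels] += 1 - eps
manual = -(soft * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
builtin = nn.CrossEntropyLoss(label_smoothing=eps)(logits, labels)
print(manual.item(), builtin.item())  # the two values match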
Use Cases
The main applications of nn.CrossEntropyLoss include:
- Image classification: e.g. CIFAR-10, ImageNet
- Text classification: sentiment analysis, topic classification
- Multi-class tasks: any classification problem with more than 2 classes
Note: by default, targets are class indices (0 to num_classes-1), not one-hot vectors.
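PyTorch 1.10 and later also accept class-probability targets of shape (batch_size, num_classes), so soft labels can be passed directly:
import torch
import torch.nn as nn
logits = torch.randn(2, 5)
# Hard targets: class indices
hard = torch.tensor([1, 3])
print(nn.CrossEntropyLoss()(logits, hard).item())
# Soft targets: per-class probabilities, each row summing to 1 (PyTorch 1.10+)
soft = torch.tensor([[0.0, 1.0, 0.0, 0.0, 0.0],
                     [0.1, 0.1, 0.1, 0.6, 0.1]])
print(nn.CrossEntropyLoss()(logits, soft).item())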
