PyTorch torch.nn.BCEWithLogitsLoss 函数
torch.nn.BCEWithLogitsLoss 是 PyTorch 中的二元交叉熵损失函数(带 Sigmoid)。
它将 Sigmoid 和 BCE 结合,数值更稳定,用于二分类任务。
函数定义
torch.nn.BCEWithLogitsLoss(weight=None, reduction='mean', pos_weight=None)
参数:
weight: 手动权重pos_weight: 正类权重,用于类别不平衡
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn
criterion = nn.BCEWithLogitsLoss()
# 未归一化的 logits
logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
# 二元标签
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])
loss = criterion(logits, targets)
print("BCE Loss:", loss.item())
# 手动验证:Sigmoid + BCE
sigmoid = torch.sigmoid(logits)
bce = nn.BCELoss()(sigmoid, targets)
print("手动 BCE:", bce.item())
import torch.nn as nn
criterion = nn.BCEWithLogitsLoss()
# 未归一化的 logits
logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
# 二元标签
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])
loss = criterion(logits, targets)
print("BCE Loss:", loss.item())
# 手动验证:Sigmoid + BCE
sigmoid = torch.sigmoid(logits)
bce = nn.BCELoss()(sigmoid, targets)
print("手动 BCE:", bce.item())
示例 2: 类别不平衡
实例
import torch
import torch.nn as nn
# 正类权重:增加正样本的重要性
pos_weight = torch.tensor([5.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits = torch.randn(10, 1)
targets = torch.zeros(10, 1)
targets[:2] = 1.0 # 正类很少
loss = criterion(logits, targets)
print("加权 BCE Loss:", loss.item())
import torch.nn as nn
# 正类权重:增加正样本的重要性
pos_weight = torch.tensor([5.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits = torch.randn(10, 1)
targets = torch.zeros(10, 1)
targets[:2] = 1.0 # 正类很少
loss = criterion(logits, targets)
print("加权 BCE Loss:", loss.item())
示例 3: 多标签分类
实例
import torch
import torch.nn as nn
# 多标签二分类
criterion = nn.BCEWithLogitsLoss()
# batch=4, 5个类别,每个可以是 0 或 1
logits = torch.randn(4, 5)
labels = torch.randint(0, 2, (4, 5)).float()
loss = criterion(logits, labels)
print("多标签 Loss:", loss.item())
import torch.nn as nn
# 多标签二分类
criterion = nn.BCEWithLogitsLoss()
# batch=4, 5个类别,每个可以是 0 或 1
logits = torch.randn(4, 5)
labels = torch.randint(0, 2, (4, 5)).float()
loss = criterion(logits, labels)
print("多标签 Loss:", loss.item())
使用场景
- 二分类: 单标签
- 多标签分类: 每个标签独立
- 类别不平衡: 使用 pos_weight
注意:输入是 logits,不需提前 Sigmoid。

PyTorch torch.nn 参考手册