现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.BCEWithLogitsLoss 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.BCEWithLogitsLoss 是 PyTorch 中的二元交叉熵损失函数(带 Sigmoid)。

它将 Sigmoid 和 BCE 结合,数值更稳定,用于二分类任务。

函数定义

torch.nn.BCEWithLogitsLoss(weight=None, reduction='mean', pos_weight=None)

参数:

  • weight: 手动权重
  • pos_weight: 正类权重,用于类别不平衡

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

# 未归一化的 logits
logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
# 二元标签
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])

loss = criterion(logits, targets)
print("BCE Loss:", loss.item())

# 手动验证:Sigmoid + BCE
sigmoid = torch.sigmoid(logits)
bce = nn.BCELoss()(sigmoid, targets)
print("手动 BCE:", bce.item())

示例 2: 类别不平衡

实例

import torch
import torch.nn as nn

# 正类权重:增加正样本的重要性
pos_weight = torch.tensor([5.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(10, 1)
targets = torch.zeros(10, 1)
targets[:2] = 1.0  # 正类很少

loss = criterion(logits, targets)
print("加权 BCE Loss:", loss.item())

示例 3: 多标签分类

实例

import torch
import torch.nn as nn

# 多标签二分类
criterion = nn.BCEWithLogitsLoss()

# batch=4, 5个类别,每个可以是 0 或 1
logits = torch.randn(4, 5)
labels = torch.randint(0, 2, (4, 5)).float()

loss = criterion(logits, labels)
print("多标签 Loss:", loss.item())

使用场景

  • 二分类: 单标签
  • 多标签分类: 每个标签独立
  • 类别不平衡: 使用 pos_weight

注意:输入是 logits,不需提前 Sigmoid。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册