现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.MSELoss 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.MSELoss 是 PyTorch 中的均方误差损失函数。

它计算预测值与目标值之间差值的平方均值,常用于回归任务。

函数定义

torch.nn.MSELoss(reduction='mean')

参数说明:

  • reduction (str): 损失聚合方式。可选 'mean''sum''none'。默认为 'mean'

数学原理

MSE 损失公式:

MSE = (1/n) * Σ(y_pred - y_true)²

使用示例

示例 1: 基本用法

计算回归损失:

实例

import torch
import torch.nn as nn

# 创建 MSE 损失
criterion = nn.MSELoss()

# 预测值和真实值
predictions = torch.tensor([3.0, 4.0, 5.0])
targets = torch.tensor([2.8, 4.2, 4.9])

# 计算损失
loss = criterion(predictions, targets)

print("预测值:", predictions.tolist())
print("目标值:", targets.tolist())
print("MSE 损失:", loss.item())

# 手动验证
manual_mse = ((predictions - targets) ** 2).mean()
print("手动计算:", manual_mse.item())

示例 2: 完整的回归训练

回归任务训练流程:

实例

import torch
import torch.nn as nn
import torch.optim as optim

# 简单回归模型
class RegressionNet(nn.Module):
    def __init__(self):
        super(RegressionNet, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

model = RegressionNet()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 训练数据
X = torch.randn(100, 10)
y = torch.randn(100, 1)

# 训练步骤
model.train()
optimizer.zero_grad()
predictions = model(X)
loss = criterion(predictions, y)
loss.backward()
optimizer.step()

print("训练损失:", loss.item())

示例 3: reduction 参数

不同聚合方式:

实例

import torch
import torch.nn as nn

pred = torch.tensor([3.0, 4.0, 5.0, 6.0])
target = torch.tensor([3.5, 3.8, 5.2, 5.5])

# mean: 平均损失
loss_mean = nn.MSELoss(reduction='mean')(pred, target)

# sum: 总损失
loss_sum = nn.MSELoss(reduction='sum')(pred, target)

# none: 不聚合
loss_none = nn.MSELoss(reduction='none')(pred, target)

print("Mean:", loss_mean.item())
print("Sum:", loss_sum.item())
print("None:", loss_none.tolist())

示例 4: 处理多维输出

多维回归:

实例

import torch
import torch.nn as nn

# 多维输出
pred = torch.randn(4, 3)  # batch=4, 3个输出
target = torch.randn(4, 3)

criterion = nn.MSELoss()
loss = criterion(pred, target)

print("预测形状:", pred.shape)
print("目标形状:", target.shape)
print("MSE 损失:", loss.item())

常见问题

Q1: MSE 和 MAE 哪个好?

  • MSE: 对异常值敏感,梯度稳定
  • MAE (L1): 对异常值鲁棒,梯度不稳定

Q2: 输出有负值能用 MSE 吗?

可以,MSE 不限制输出范围。

Q3: 分类任务用什么损失?

分类任务用 CrossEntropyLoss。


使用场景

nn.MSELoss 主要应用场景包括:

  • 回归任务: 房价预测、数值估计
  • 连续值预测: 目标追踪
  • 生成模型: VAE、GAN 的某些损失

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册