现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.Softmax 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.Softmax 是 PyTorch 中的 Softmax 激活函数。

它将输入转换为概率分布,所有输出之和为 1。

函数定义

torch.nn.Softmax(dim=None)

参数:

  • dim: 进行 Softmax 的维度

公式

Softmax(x_i) = exp(x_i) / sum(exp(x_j))

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

softmax = nn.Softmax(dim=1)

logits = torch.tensor([[2.0, 1.0, 0.1], [1.0, 2.0, 0.5]])
probs = softmax(logits)

print("Logits:", logits.tolist())
print("概率:", probs.tolist())
print("每行和:", probs.sum(dim=1).tolist())

示例 2: dim 参数

实例

import torch
import torch.nn as nn

# 3D 输入
x = torch.randn(2, 3, 4)

# 按不同维度 softmax
print("dim=1:", nn.Softmax(dim=1)(x).sum(dim=1)[:1])
print("dim=2:", nn.Softmax(dim=2)(x).sum(dim=2)[:1])
print("dim=-1:", nn.Softmax(dim=-1)(x).sum(dim=-1)[:1])

示例 3: 分类输出

实例

import torch
import torch.nn as nn

model = nn.Linear(784, 10)

# 输出 logits
logits = model(torch.randn(4, 784))

# 转换为概率
probs = nn.Softmax(dim=1)(logits)

print("预测类别:", probs.argmax(dim=1).tolist())
print("最高概率:", probs.max(dim=1).values.tolist())

使用场景

  • 多分类输出: 概率分布
  • 注意力机制
  • 概率模型

注意:Softmax 后的值都是正数且和为 1。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册