PyTorch torch.nn.Softmax 函数
torch.nn.Softmax 是 PyTorch 中的 Softmax 激活函数。
它将输入转换为概率分布,所有输出之和为 1。
函数定义
torch.nn.Softmax(dim=None)
参数:
dim: 进行 Softmax 的维度
公式
Softmax(x_i) = exp(x_i) / sum(exp(x_j))
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn
softmax = nn.Softmax(dim=1)
logits = torch.tensor([[2.0, 1.0, 0.1], [1.0, 2.0, 0.5]])
probs = softmax(logits)
print("Logits:", logits.tolist())
print("概率:", probs.tolist())
print("每行和:", probs.sum(dim=1).tolist())
import torch.nn as nn
softmax = nn.Softmax(dim=1)
logits = torch.tensor([[2.0, 1.0, 0.1], [1.0, 2.0, 0.5]])
probs = softmax(logits)
print("Logits:", logits.tolist())
print("概率:", probs.tolist())
print("每行和:", probs.sum(dim=1).tolist())
示例 2: dim 参数
实例
import torch
import torch.nn as nn
# 3D 输入
x = torch.randn(2, 3, 4)
# 按不同维度 softmax
print("dim=1:", nn.Softmax(dim=1)(x).sum(dim=1)[:1])
print("dim=2:", nn.Softmax(dim=2)(x).sum(dim=2)[:1])
print("dim=-1:", nn.Softmax(dim=-1)(x).sum(dim=-1)[:1])
import torch.nn as nn
# 3D 输入
x = torch.randn(2, 3, 4)
# 按不同维度 softmax
print("dim=1:", nn.Softmax(dim=1)(x).sum(dim=1)[:1])
print("dim=2:", nn.Softmax(dim=2)(x).sum(dim=2)[:1])
print("dim=-1:", nn.Softmax(dim=-1)(x).sum(dim=-1)[:1])
示例 3: 分类输出
实例
import torch
import torch.nn as nn
model = nn.Linear(784, 10)
# 输出 logits
logits = model(torch.randn(4, 784))
# 转换为概率
probs = nn.Softmax(dim=1)(logits)
print("预测类别:", probs.argmax(dim=1).tolist())
print("最高概率:", probs.max(dim=1).values.tolist())
import torch.nn as nn
model = nn.Linear(784, 10)
# 输出 logits
logits = model(torch.randn(4, 784))
# 转换为概率
probs = nn.Softmax(dim=1)(logits)
print("预测类别:", probs.argmax(dim=1).tolist())
print("最高概率:", probs.max(dim=1).values.tolist())
使用场景
- 多分类输出: 概率分布
- 注意力机制
- 概率模型
注意:Softmax 后的值都是正数且和为 1。

PyTorch torch.nn 参考手册