现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.multinomial 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.multinomial 是 PyTorch 中用于多项式采样的函数,从输入张量的每一行中按照概率分布进行采样。

函数定义

torch.multinomial(input, num_samples, replacement=False, generator=None)

参数说明

  • input - 概率分布张量,每行是一个概率分布,所有概率之和应为1
  • num_samples - 采样数量
  • replacement - 是否允许放回采样(默认 False)
  • generator - 随机数生成器(可选)

使用示例

实例

import torch

# 定义概率分布
weights = torch.tensor([[0.0, 1.0],   # 第二个类别概率为1
                        [0.5, 0.5],   # 两个类别概率相等
                        [0.2, 0.3, 0.5]])  # 三个类别

# 不放回采样
result = torch.multinomial(weights, num_samples=2, replacement=False)
print("概率分布:")
print(weights)
print("不放回采样结果 (每行采样2个):")
print(result)

# 放回采样
result_with_replacement = torch.multinomial(weights, num_samples=5, replacement=True)
print("放回采样结果 (每行采样5个):")
print(result_with_replacement)

Pytorch torch 参考手册 Pytorch torch 参考手册