PyTorch torch.multinomial 函数
torch.multinomial 是 PyTorch 中用于多项式采样的函数,从输入张量的每一行中按照概率分布进行采样。
函数定义
torch.multinomial(input, num_samples, replacement=False, generator=None)
参数说明
input- 概率分布张量,每行是一个概率分布,所有概率之和应为1num_samples- 采样数量replacement- 是否允许放回采样(默认 False)generator- 随机数生成器(可选)
使用示例
实例
import torch
# 定义概率分布
weights = torch.tensor([[0.0, 1.0], # 第二个类别概率为1
[0.5, 0.5], # 两个类别概率相等
[0.2, 0.3, 0.5]]) # 三个类别
# 不放回采样
result = torch.multinomial(weights, num_samples=2, replacement=False)
print("概率分布:")
print(weights)
print("不放回采样结果 (每行采样2个):")
print(result)
# 放回采样
result_with_replacement = torch.multinomial(weights, num_samples=5, replacement=True)
print("放回采样结果 (每行采样5个):")
print(result_with_replacement)
# 定义概率分布
weights = torch.tensor([[0.0, 1.0], # 第二个类别概率为1
[0.5, 0.5], # 两个类别概率相等
[0.2, 0.3, 0.5]]) # 三个类别
# 不放回采样
result = torch.multinomial(weights, num_samples=2, replacement=False)
print("概率分布:")
print(weights)
print("不放回采样结果 (每行采样2个):")
print(result)
# 放回采样
result_with_replacement = torch.multinomial(weights, num_samples=5, replacement=True)
print("放回采样结果 (每行采样5个):")
print(result_with_replacement)

Pytorch torch 参考手册