PyTorch torch.bernoulli 函数
torch.bernoulli 是 PyTorch 中用于从伯努利分布生成随机数的函数。
函数定义
torch.bernoulli(input, *, generator=None, out=None) torch.bernoulli(p, size, *, generator=None, out=None)
参数说明
input- 概率值或包含概率的张量(每个元素表示对应位置为1的概率)p- 概率值(当 input 不是张量时使用)size- 输出张量的形状generator- 随机数生成器(可选)out- 输出张量(可选)
使用示例
实例
import torch
# 使用概率张量生成伯努利分布随机数
probs = torch.tensor([0.1, 0.5, 0.9])
result = torch.bernoulli(probs)
print("概率张量:", probs)
print("伯努利采样结果:", result)
# 使用固定概率生成随机张量
result2 = torch.bernoulli(0.5, (3, 3))
print("固定概率 0.5 生成的 3x3 随机张量:")
print(result2)
# 使用概率张量生成伯努利分布随机数
probs = torch.tensor([0.1, 0.5, 0.9])
result = torch.bernoulli(probs)
print("概率张量:", probs)
print("伯努利采样结果:", result)
# 使用固定概率生成随机张量
result2 = torch.bernoulli(0.5, (3, 3))
print("固定概率 0.5 生成的 3x3 随机张量:")
print(result2)

Pytorch torch 参考手册