现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.bernoulli 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.bernoulli 是 PyTorch 中用于从伯努利分布生成随机数的函数。

函数定义

torch.bernoulli(input, *, generator=None, out=None)
torch.bernoulli(p, size, *, generator=None, out=None)

参数说明

  • input - 概率值或包含概率的张量(每个元素表示对应位置为1的概率)
  • p - 概率值(当 input 不是张量时使用)
  • size - 输出张量的形状
  • generator - 随机数生成器(可选)
  • out - 输出张量(可选)

使用示例

实例

import torch

# 使用概率张量生成伯努利分布随机数
probs = torch.tensor([0.1, 0.5, 0.9])
result = torch.bernoulli(probs)
print("概率张量:", probs)
print("伯努利采样结果:", result)

# 使用固定概率生成随机张量
result2 = torch.bernoulli(0.5, (3, 3))
print("固定概率 0.5 生成的 3x3 随机张量:")
print(result2)

Pytorch torch 参考手册 Pytorch torch 参考手册