PyTorch torch.prod 函数
torch.prod 是 PyTorch 中用于返回张量所有元素的乘积的函数。
函数定义
torch.prod(input, dim, keepdim=False)
使用示例
实例
import torch
x = torch.tensor([1, 2, 3, 4])
# 返回所有元素的乘积
print("所有元素乘积:", torch.prod(x))
# 沿 dim=0 乘积
y = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("dim=0 乘积:", torch.prod(y, dim=0))
print("dim=1 乘积:", torch.prod(y, dim=1))
x = torch.tensor([1, 2, 3, 4])
# 返回所有元素的乘积
print("所有元素乘积:", torch.prod(x))
# 沿 dim=0 乘积
y = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("dim=0 乘积:", torch.prod(y, dim=0))
print("dim=1 乘积:", torch.prod(y, dim=1))
输出结果为:
所有元素乘积: tensor(24) dim=0 乘积: tensor([4, 10, 18]) dim=1 乘积: tensor([ 6, 120])

Pytorch torch 参考手册