现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.cumprod 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.cumprod 是 PyTorch 中用于计算累积积的函数。它返回沿指定维度的累积乘积,即从开始到当前位置所有元素的乘积。

函数定义

torch.cumprod(input, dim, dtype=None)

使用示例

实例

import torch

# 计算累积积
x = torch.tensor([1, 2, 3, 4, 5])
result = torch.cumprod(x, dim=0)
print("输入:", x)
print("累积积:", result)
# 输出: tensor([1, 2, 6, 24, 120])
# 说明: 1, 1*2=2, 1*2*3=6, 1*2*3*4=24, 1*2*3*4*5=120

# 2 维张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
# 按列累积
result_col = torch.cumprod(x, dim=0)
print("n按列累积积:")
print(result_col)
# tensor([[ 1,  2,  3],
#         [ 4, 10, 18]])

# 按行累积
result_row = torch.cumprod(x, dim=1)
print("n按行累积积:")
print(result_row)
# tensor([[  1,   2,   6],
#         [  4,  20, 120]])

Pytorch torch 参考手册 Pytorch torch 参考手册