现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.logcumsumexp 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.logcumsumexp 是 PyTorch 中用于计算累积和指数对数的函数。它先对元素求指数,然后累积求和,最后取对数,可以避免数值溢出。

函数定义

torch.logcumsumexp(input, dim)

参数说明:

  • input: 输入张量
  • dim: 累积求和的维度

使用示例

实例

import torch

# 创建张量
x = torch.tensor([1.0, 2.0, 3.0])

# 计算累积和的指数对数
y = torch.logcumsumexp(x, dim=0)
print(y)

输出结果为:

tensor([1.0000, 2.3133, 3.1000])

实例

import torch

# 验证:log(exp(1) + exp(2)) = log(exp(1) + exp(2))
# log(e^1 + e^2) = log(e^1 + e^2) ≈ 2.3133
x = torch.tensor([1.0, 2.0])
y = torch.logcumsumexp(x, dim=0)

# 对比直接计算
import math
direct = math.log(math.exp(1) + math.exp(2))
print("logcumsumexp:", y[1].item())
print("direct:", direct)

输出结果为:

logcumsumexp: 2.313261505126953
direct: 2.313261505126953

Pytorch torch 参考手册 Pytorch torch 参考手册