PyTorch torch.logcumsumexp 函数
torch.logcumsumexp 是 PyTorch 中用于计算累积和指数对数的函数。它先对元素求指数,然后累积求和,最后取对数,可以避免数值溢出。
函数定义
torch.logcumsumexp(input, dim)
参数说明:
input: 输入张量dim: 累积求和的维度
使用示例
实例
import torch
# 创建张量
x = torch.tensor([1.0, 2.0, 3.0])
# 计算累积和的指数对数
y = torch.logcumsumexp(x, dim=0)
print(y)
# 创建张量
x = torch.tensor([1.0, 2.0, 3.0])
# 计算累积和的指数对数
y = torch.logcumsumexp(x, dim=0)
print(y)
输出结果为:
tensor([1.0000, 2.3133, 3.1000])
实例
import torch
# 验证:log(exp(1) + exp(2)) = log(exp(1) + exp(2))
# log(e^1 + e^2) = log(e^1 + e^2) ≈ 2.3133
x = torch.tensor([1.0, 2.0])
y = torch.logcumsumexp(x, dim=0)
# 对比直接计算
import math
direct = math.log(math.exp(1) + math.exp(2))
print("logcumsumexp:", y[1].item())
print("direct:", direct)
# 验证:log(exp(1) + exp(2)) = log(exp(1) + exp(2))
# log(e^1 + e^2) = log(e^1 + e^2) ≈ 2.3133
x = torch.tensor([1.0, 2.0])
y = torch.logcumsumexp(x, dim=0)
# 对比直接计算
import math
direct = math.log(math.exp(1) + math.exp(2))
print("logcumsumexp:", y[1].item())
print("direct:", direct)
输出结果为:
logcumsumexp: 2.313261505126953 direct: 2.313261505126953

Pytorch torch 参考手册