PyTorch torch.cummin 函数
torch.cummin 是 PyTorch 中用于计算累积最小值的函数。它返回沿指定维度的累积最小值以及对应的索引。
函数定义
torch.cummin(input, dim, dtype=None)
使用示例
实例
import torch
# 计算累积最小值
x = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6])
values, indices = torch.cummin(x, dim=0)
print("输入:", x)
print("累积最小值:", values)
print("最小值索引:", indices)
# 2 维张量
x = torch.tensor([[3, 1, 4], [1, 5, 9], [2, 6, 5]], dtype=torch.float32)
# 按列累积
values_col, indices_col = torch.cummin(x, dim=0)
print("n按列累积:")
print("累积最小值:")
print(values_col)
print("索引:")
print(indices_col)
# 按行累积
values_row, indices_row = torch.cummin(x, dim=1)
print("n按行累积:")
print("累积最小值:")
print(values_row)
# 计算累积最小值
x = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6])
values, indices = torch.cummin(x, dim=0)
print("输入:", x)
print("累积最小值:", values)
print("最小值索引:", indices)
# 2 维张量
x = torch.tensor([[3, 1, 4], [1, 5, 9], [2, 6, 5]], dtype=torch.float32)
# 按列累积
values_col, indices_col = torch.cummin(x, dim=0)
print("n按列累积:")
print("累积最小值:")
print(values_col)
print("索引:")
print(indices_col)
# 按行累积
values_row, indices_row = torch.cummin(x, dim=1)
print("n按行累积:")
print("累积最小值:")
print(values_row)

Pytorch torch 参考手册