现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.combinations 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.combinations 是 PyTorch 中用于计算输入张量元素的所有 r 元组合的函数。它返回输入张量中所有可能的长度为 r 的组合。

函数定义

torch.combinations(input, r=2, with_replacement=False)

使用示例

实例

import torch

# 计算所有 2 元组合
x = torch.tensor([1, 2, 3, 4])
result = torch.combinations(x, r=2)
print("输入:", x)
print("2 元组合:")
print(result)
# tensor([[1, 2],
#         [1, 3],
#         [1, 4],
#         [2, 3],
#         [2, 4],
#         [3, 4]])

# 3 元组合
result3 = torch.combinations(x, r=3)
print("3 元组合:")
print(result3)

# 带放回的组合 (with_replacement=True)
result_with_replacement = torch.combinations(x, r=2, with_replacement=True)
print("带放回的 2 元组合:")
print(result_with_replacement)

Pytorch torch 参考手册 Pytorch torch 参考手册