PyTorch torch.combinations 函数
torch.combinations 是 PyTorch 中用于计算输入张量元素的所有 r 元组合的函数。它返回输入张量中所有可能的长度为 r 的组合。
函数定义
torch.combinations(input, r=2, with_replacement=False)
使用示例
实例
import torch
# 计算所有 2 元组合
x = torch.tensor([1, 2, 3, 4])
result = torch.combinations(x, r=2)
print("输入:", x)
print("2 元组合:")
print(result)
# tensor([[1, 2],
# [1, 3],
# [1, 4],
# [2, 3],
# [2, 4],
# [3, 4]])
# 3 元组合
result3 = torch.combinations(x, r=3)
print("3 元组合:")
print(result3)
# 带放回的组合 (with_replacement=True)
result_with_replacement = torch.combinations(x, r=2, with_replacement=True)
print("带放回的 2 元组合:")
print(result_with_replacement)
# 计算所有 2 元组合
x = torch.tensor([1, 2, 3, 4])
result = torch.combinations(x, r=2)
print("输入:", x)
print("2 元组合:")
print(result)
# tensor([[1, 2],
# [1, 3],
# [1, 4],
# [2, 3],
# [2, 4],
# [3, 4]])
# 3 元组合
result3 = torch.combinations(x, r=3)
print("3 元组合:")
print(result3)
# 带放回的组合 (with_replacement=True)
result_with_replacement = torch.combinations(x, r=2, with_replacement=True)
print("带放回的 2 元组合:")
print(result_with_replacement)

Pytorch torch 参考手册