现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.cdist 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.cdist 是 PyTorch 中用于计算两组点之间欧氏距离矩阵的函数。它计算第一个输入中每个点与第二个输入中每个点之间的欧氏距离,返回一个距离矩阵。

函数定义

torch.cdist(input1, input2, p=2.0, compute_mode='use_mm_for_euclidean_dist')

使用示例

实例

import torch

# 计算两组点之间的欧氏距离
x = torch.tensor([[0, 0], [1, 1], [2, 2]])  # 3 个 2D 点
y = torch.tensor([[0, 0], [1, 0], [2, 0]])  # 3 个 2D 点

# 距离矩阵形状: (3, 3)
distances = torch.cdist(x, y)
print("点 x:")
print(x)
print("点 y:")
print(y)
print("欧氏距离矩阵:")
print(distances)

# 使用不同的 p 值(曼哈顿距离 p=1)
dist_l1 = torch.cdist(x, y, p=1.0)
print("L1 距离 (p=1):")
print(dist_l1)

# 使用 p=无穷大(切比雪夫距离)
dist_inf = torch.cdist(x, y, p=float('inf'))
print("切比雪夫距离 (p=inf):")
print(dist_inf)

Pytorch torch 参考手册 Pytorch torch 参考手册