PyTorch torch.cdist 函数
torch.cdist 是 PyTorch 中用于计算两组点之间欧氏距离矩阵的函数。它计算第一个输入中每个点与第二个输入中每个点之间的欧氏距离,返回一个距离矩阵。
函数定义
torch.cdist(input1, input2, p=2.0, compute_mode='use_mm_for_euclidean_dist')
使用示例
实例
import torch
# 计算两组点之间的欧氏距离
x = torch.tensor([[0, 0], [1, 1], [2, 2]]) # 3 个 2D 点
y = torch.tensor([[0, 0], [1, 0], [2, 0]]) # 3 个 2D 点
# 距离矩阵形状: (3, 3)
distances = torch.cdist(x, y)
print("点 x:")
print(x)
print("点 y:")
print(y)
print("欧氏距离矩阵:")
print(distances)
# 使用不同的 p 值(曼哈顿距离 p=1)
dist_l1 = torch.cdist(x, y, p=1.0)
print("L1 距离 (p=1):")
print(dist_l1)
# 使用 p=无穷大(切比雪夫距离)
dist_inf = torch.cdist(x, y, p=float('inf'))
print("切比雪夫距离 (p=inf):")
print(dist_inf)
# 计算两组点之间的欧氏距离
x = torch.tensor([[0, 0], [1, 1], [2, 2]]) # 3 个 2D 点
y = torch.tensor([[0, 0], [1, 0], [2, 0]]) # 3 个 2D 点
# 距离矩阵形状: (3, 3)
distances = torch.cdist(x, y)
print("点 x:")
print(x)
print("点 y:")
print(y)
print("欧氏距离矩阵:")
print(distances)
# 使用不同的 p 值(曼哈顿距离 p=1)
dist_l1 = torch.cdist(x, y, p=1.0)
print("L1 距离 (p=1):")
print(dist_l1)
# 使用 p=无穷大(切比雪夫距离)
dist_inf = torch.cdist(x, y, p=float('inf'))
print("切比雪夫距离 (p=inf):")
print(dist_inf)

Pytorch torch 参考手册