PyTorch torch.dist 函数
torch.dist 是 PyTorch 中用于计算两个张量之间距离的函数。
函数定义
torch.dist(input, other, p=2)
参数说明:
input- 输入张量other- 目标张量p- 范数类型,默认为 2(欧几里得距离)
使用示例
实例
import torch
# 创建两个张量
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])
# 计算欧几里得距离(p=2)
dist = torch.dist(x, y)
print("欧几里得距离:", dist)
# 计算曼哈顿距离(p=1)
dist_l1 = torch.dist(x, y, p=1)
print("曼哈顿距离:", dist_l1)
# 创建两个张量
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])
# 计算欧几里得距离(p=2)
dist = torch.dist(x, y)
print("欧几里得距离:", dist)
# 计算曼哈顿距离(p=1)
dist_l1 = torch.dist(x, y, p=1)
print("曼哈顿距离:", dist_l1)

Pytorch torch 参考手册