PyTorch torch.eq 函数
torch.eq 是 PyTorch 中用于逐元素相等比较的函数。
函数定义
torch.eq(input, other, out)
使用示例
实例
import torch
a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([1, 2, 0, 4])
result = torch.eq(a, b)
print(result)
a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([1, 2, 0, 4])
result = torch.eq(a, b)
print(result)
输出结果为:
tensor([True, True, False, True])

Pytorch torch 参考手册