PyTorch torch.where 函数
torch.where 是 PyTorch 中用于根据条件返回元素的函数。
函数定义
torch.where(condition, input, other)
使用示例
实例
import torch
condition = torch.tensor([[True, False], [False, True]])
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[10, 20], [30, 40]])
result = torch.where(condition, x, y)
print(result)
condition = torch.tensor([[True, False], [False, True]])
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[10, 20], [30, 40]])
result = torch.where(condition, x, y)
print(result)
输出结果为:
tensor([[ 1, 20],
[30, 4]])

Pytorch torch 参考手册