PyTorch torch.argwhere 函数
torch.argwhere 是 PyTorch 中用于返回满足条件的元素索引的函数。它返回输入张量中值为 True(非零)的元素索引。
函数定义
torch.argwhere(input)
参数:
input(Tensor): 输入张量。
返回值:
torch.Tensor: 返回一个二维张量,每行是一个满足条件的元素索引。
使用示例
实例
import torch
# 创建一个张量
x = torch.tensor([[1, 0, 2],
[0, 3, 0],
[4, 5, 0]])
print("原始张量:")
print(x)
# 返回非零元素的索引
indices = torch.argwhere(x)
print("n非零元素的索引:")
print(indices)
# 创建一个张量
x = torch.tensor([[1, 0, 2],
[0, 3, 0],
[4, 5, 0]])
print("原始张量:")
print(x)
# 返回非零元素的索引
indices = torch.argwhere(x)
print("n非零元素的索引:")
print(indices)
输出结果为:
原始张量:
tensor([[1, 0, 2],
[0, 3, 0],
[4, 5, 0]])
非零元素的索引:
tensor([[0, 0],
[0, 2],
[1, 1],
[2, 0],
[2, 1]])
实例
import torch
# 布尔条件
x = torch.tensor([[True, False, True],
[False, True, False],
[True, True, False]])
print("布尔张量:")
print(x)
indices = torch.argwhere(x)
print("nTrue 值的索引:")
print(indices)
# 布尔条件
x = torch.tensor([[True, False, True],
[False, True, False],
[True, True, False]])
print("布尔张量:")
print(x)
indices = torch.argwhere(x)
print("nTrue 值的索引:")
print(indices)
输出结果为:
布尔张量:
tensor([[True, False, True],
[False, True, False],
[True, True, False]])
True 值的索引:
tensor([[0, 0],
[0, 2],
[1, 1],
[2, 0],
[2, 1]])
实例
import torch
# 找到大于某个值的元素
x = torch.randn(3, 4)
threshold = 0
print("原始张量:")
print(x)
# 找到大于threshold的元素索引
indices = torch.argwhere(x > threshold)
print(f"n大于 {threshold} 的元素索引:")
print(indices)
# 也可以使用 nonzero 函数,效果相同
indices2 = torch.nonzero(x > threshold)
print("n使用 nonzero 的结果:")
print(indices2)
# 找到大于某个值的元素
x = torch.randn(3, 4)
threshold = 0
print("原始张量:")
print(x)
# 找到大于threshold的元素索引
indices = torch.argwhere(x > threshold)
print(f"n大于 {threshold} 的元素索引:")
print(indices)
# 也可以使用 nonzero 函数,效果相同
indices2 = torch.nonzero(x > threshold)
print("n使用 nonzero 的结果:")
print(indices2)
输出结果为:
原始张量:
tensor([[-1.2345, 0.5678, -0.8901, 1.2345],
[ 0.3456, -0.6789, 0.9012, -0.1234],
[-0.5678, 1.2345, -0.3456, 0.7890]])
大于 0 的元素索引:
tensor([[0, 1],
[0, 3],
[1, 0],
[1, 2],
[2, 1],
[2, 3]])
使用 nonzero 的结果:
tensor([[0, 1],
[0, 3],
[1, 0],
[1, 2],
[2, 1],
[2, 3]])
实例
import torch
# 1D张量
x = torch.tensor([1, 0, 0, 4, 0, 5, 0])
indices = torch.argwhere(x)
print("1D张量的索引:")
print(indices.squeeze()) # 去掉多余的维度
# 1D张量
x = torch.tensor([1, 0, 0, 4, 0, 5, 0])
indices = torch.argwhere(x)
print("1D张量的索引:")
print(indices.squeeze()) # 去掉多余的维度
输出结果为:
1D张量的索引: tensor([0, 3, 5])
实例
import torch
# 应用:找到满足条件的元素并修改
x = torch.randn(5, 5)
# 找到所有大于0的元素索引
pos_indices = torch.argwhere(x > 0)
print("原始张量中大于0的位置:")
print(pos_indices)
# 使用这些索引修改值
for idx in pos_indices:
x[idx[0], idx[1]] = 100
print("n修改后的张量:")
print(x)
# 应用:找到满足条件的元素并修改
x = torch.randn(5, 5)
# 找到所有大于0的元素索引
pos_indices = torch.argwhere(x > 0)
print("原始张量中大于0的位置:")
print(pos_indices)
# 使用这些索引修改值
for idx in pos_indices:
x[idx[0], idx[1]] = 100
print("n修改后的张量:")
print(x)
注意:torch.argwhere 是 torch.nonzero 的别名,两者功能完全相同。它返回的是元素索引,每行对应一个满足条件的元素在原始张量中的位置。

Pytorch torch 参考手册