现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.argwhere 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.argwhere 是 PyTorch 中用于返回满足条件的元素索引的函数。它返回输入张量中值为 True(非零)的元素索引。

函数定义

torch.argwhere(input)

参数:

  • input (Tensor): 输入张量。

返回值:

  • torch.Tensor: 返回一个二维张量,每行是一个满足条件的元素索引。

使用示例

实例

import torch

# 创建一个张量
x = torch.tensor([[1, 0, 2],
                  [0, 3, 0],
                  [4, 5, 0]])

print("原始张量:")
print(x)

# 返回非零元素的索引
indices = torch.argwhere(x)

print("n非零元素的索引:")
print(indices)

输出结果为:

原始张量:
tensor([[1, 0, 2],
        [0, 3, 0],
        [4, 5, 0]])

非零元素的索引:
tensor([[0, 0],
        [0, 2],
        [1, 1],
        [2, 0],
        [2, 1]])

实例

import torch

# 布尔条件
x = torch.tensor([[True, False, True],
                  [False, True, False],
                  [True, True, False]])

print("布尔张量:")
print(x)

indices = torch.argwhere(x)
print("nTrue 值的索引:")
print(indices)

输出结果为:

布尔张量:
tensor([[True, False, True],
        [False, True, False],
        [True, True, False]])

True 值的索引:
tensor([[0, 0],
        [0, 2],
        [1, 1],
        [2, 0],
        [2, 1]])

实例

import torch

# 找到大于某个值的元素
x = torch.randn(3, 4)
threshold = 0

print("原始张量:")
print(x)

# 找到大于threshold的元素索引
indices = torch.argwhere(x > threshold)

print(f"n大于 {threshold} 的元素索引:")
print(indices)

# 也可以使用 nonzero 函数,效果相同
indices2 = torch.nonzero(x > threshold)
print("n使用 nonzero 的结果:")
print(indices2)

输出结果为:

原始张量:
tensor([[-1.2345,  0.5678, -0.8901,  1.2345],
        [ 0.3456, -0.6789,  0.9012, -0.1234],
        [-0.5678,  1.2345, -0.3456,  0.7890]])

大于 0 的元素索引:
tensor([[0, 1],
        [0, 3],
        [1, 0],
        [1, 2],
        [2, 1],
        [2, 3]])

使用 nonzero 的结果:
tensor([[0, 1],
        [0, 3],
        [1, 0],
        [1, 2],
        [2, 1],
        [2, 3]])

实例

import torch

# 1D张量
x = torch.tensor([1, 0, 0, 4, 0, 5, 0])

indices = torch.argwhere(x)
print("1D张量的索引:")
print(indices.squeeze())  # 去掉多余的维度

输出结果为:

1D张量的索引:
tensor([0, 3, 5])

实例

import torch

# 应用:找到满足条件的元素并修改
x = torch.randn(5, 5)

# 找到所有大于0的元素索引
pos_indices = torch.argwhere(x > 0)

print("原始张量中大于0的位置:")
print(pos_indices)

# 使用这些索引修改值
for idx in pos_indices:
    x[idx[0], idx[1]] = 100

print("n修改后的张量:")
print(x)

注意:torch.argwheretorch.nonzero 的别名,两者功能完全相同。它返回的是元素索引,每行对应一个满足条件的元素在原始张量中的位置。


Pytorch torch 参考手册 Pytorch torch 参考手册