在pytorch Argmax中发生冲突时的索引选择


苞米地里的蒙面妖
2025-03-07 11:51:59 (26天前)


我一直在努力学习张量操作,而这个操作让我陷入了困境。
让我们说我有一个张量:

t = torch.tensor([
[1,0,0,2],
[0,3,3,0],
[4,0,0,5]


2 条回复
  1. 0# 取之 | 2019-08-31 10-32


    这是一个很好的问题,我自己偶然发现了几次。最简单的答案是,没有任何保证

    torch.argmax

    (要么

    torch.max(x, dim=k)

    ,当指定dim时也返回索引)将一致地返回相同的索引。相反,它会回来

    任何有效的索引

    到argmax值,可能是随机的。如

    这个帖子在官方论坛上

    讨论,这被认为是理想的行为。 (我知道我刚才读过另一个线程,这使得这个更明确,但我再也找不到了)。



    话虽如此,由于这种行为对我的用例来说是不可接受的,我写了下面的函数,它们会找到最左边和最右边的索引(请注意

    condition

    是传入的函数对象):




    1. def __consistent_args(input, condition, indices):
      assert len(input.shape) == 2, only works for batch x dim tensors along the dim axis
      mask = condition(input).float() * indices.unsqueeze(0).expand_as(input)
      return torch.argmax(mask, dim=1)

    2. def consistent_find_leftmost(input, condition):
      indices = torch.arange(input.size(1), 0, -1, dtype=torch.float, device=input.device)
      return __consistent_args(input, condition, indices)

    3. def consistent_find_rightmost(input, condition):
      indices = torch.arange(0, input.size(1), 1, dtype=torch.float, device=input.device)
      return __consistent_args(input, condition, indices)

    4. one example:

      consistent_find_leftmost(torch.arange(10).unsqueeze(0), lambda x: x>5)

    5. will return:

      tensor([6])

      </code>


    希望他们能帮忙! (哦,如果你有更好的实现,请告诉我)


登录 后才能参与评论