rusty1s / pytorch_scatter

PyTorch Extension Library of Optimized Scatter Operations

Home Page: https://pytorch-scatter.readthedocs.io

scatter_max bug: always returns an out-of-bounds (upper-bound) index, and the value associated with it is 0

Chaoqi-LIU opened this issue

Hi, I'm using torch 2.1.0.post303 and torch_scatter 2.1.2 with CUDA 12.2.

Recall: `scatter_max` returns both the maximum value of each segment and the index (argmax) in `src` where that maximum occurs.

The bug I encountered: the size of `src` always shows up among the returned indices (the second return value, argmax), and the value returned for that index is 0.
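To make the reported behavior concrete, here is a minimal pure-Python sketch of 1-D `scatter_max` semantics as described in this thread: a segment that receives no elements gets the value 0 and an argmax equal to the length of `src`. The helper `scatter_max_reference` is hypothetical and is not torch_scatter's implementation; breaking ties in favor of the first occurrence is an assumption of this sketch.

```python
from typing import List, Optional, Tuple


def scatter_max_reference(
    src: List[float], index: List[int], dim_size: Optional[int] = None
) -> Tuple[List[float], List[int]]:
    """Hypothetical pure-Python sketch of 1-D scatter_max semantics."""
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    if dim_size is None:
        # Number of segments defaults to the largest index + 1.
        dim_size = max(index) + 1 if index else 0
    sentinel = len(src)  # out-of-range index that marks an empty segment
    out: List[Optional[float]] = [None] * dim_size
    argmax = [sentinel] * dim_size
    for i, (value, seg) in enumerate(zip(src, index)):
        # Strict '>' keeps the first occurrence on ties (assumption).
        if out[seg] is None or value > out[seg]:
            out[seg] = value
            argmax[seg] = i
    # Empty segments report 0 as their value, as observed in this issue.
    return [0.0 if v is None else v for v in out], argmax


# Segment 1 receives no elements, so it gets value 0.0 and argmax len(src) == 4.
out, argmax = scatter_max_reference([5.0, 2.0, 7.0, 1.0], [0, 0, 2, 2])
print(out)     # [5.0, 0.0, 7.0]
print(argmax)  # [0, 4, 2]
```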

This is my temporary fix:

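```python
import torch_scatter

# in_bbox_particles_L and indices come from my own pipeline (not shown).
max_z, argmax_z = torch_scatter.scatter_max(in_bbox_particles_L[-2, :, 2], indices)
# Drop entries whose argmax is the out-of-bound index (the size of src).
bug_mask = argmax_z == indices.shape[0]
max_z = max_z[~bug_mask]
argmax_z = argmax_z[~bug_mask]
```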
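One thing the masking approach drops: after filtering, position k of `max_z` no longer lines up with segment k. If segment ids matter downstream, they can be kept alongside the surviving values. The sentinel to compare against is the size of `src` along the scatter dimension; in the 1-D case here that is also `indices.shape[0]`. The sketch below uses plain Python lists with made-up outputs (hypothetical values, not taken from the issue) so it runs without a GPU. In PyTorch, the same ids could be obtained with something like `(~bug_mask).nonzero(as_tuple=True)[0]`.

```python
# Hypothetical scatter_max outputs for a 1-D src of length 4 where
# segment 1 received no elements (argmax == len(src) marks it).
src_len = 4
max_z = [5.0, 0.0, 7.0]
argmax_z = [0, 4, 2]

# Record which segment each surviving entry belongs to while filtering,
# so position k of the filtered lists can still be mapped back to a segment.
kept = [
    (seg, value, arg)
    for seg, (value, arg) in enumerate(zip(max_z, argmax_z))
    if arg != src_len
]
segment_ids = [seg for seg, _, _ in kept]
max_z_valid = [value for _, value, _ in kept]
argmax_z_valid = [arg for _, _, arg in kept]

print(segment_ids)     # [0, 2]
print(max_z_valid)     # [5.0, 7.0]
print(argmax_z_valid)  # [0, 2]
```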

I checked with `torch.unique` whether any entry of `indices` equals the size of `src`, and none does, so this value most likely comes from `scatter_max` itself.

`scatter_min` has the same problem.

Do you mean that the argmax is filled with an invalid index in case the segment is empty? This is working as designed, and your solution is the correct way to handle this downstream.

Cool, thanks. I didn't know it was designed that way. 👍