scatter_max bug: always returns an out-of-bounds (upper) index, and the value associated with it is 0
Chaoqi-LIU opened this issue
Hi, I'm using torch 2.1.0.post303 and torch_scatter 2.1.2 with CUDA 12.2.
Recall: scatter_max returns both the maximum value of each segment and the index (argmax) of that value in src.
The bug I encountered: the size of src always shows up in the second return value (the indices), and the value associated with that index is 0.
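The behaviour described above can be reproduced with a small pure-Python model of 1-D scatter_max. This is a sketch, not torch_scatter's implementation: `scatter_max_reference` is a hypothetical helper, and it assumes that `dim_size` defaults to `max(index) + 1`, that segments receiving no element keep the value 0 and the argmax `len(src)`, and that ties resolve to the first occurrence (the real kernels may break ties differently).

```python
from typing import List, Optional, Tuple


def scatter_max_reference(
    src: List[float], index: List[int], dim_size: Optional[int] = None
) -> Tuple[List[float], List[int]]:
    """Hypothetical pure-Python model of 1-D scatter_max (illustration only)."""
    n = len(src)
    if dim_size is None:
        # Assumed default: one output slot per segment id up to max(index).
        dim_size = max(index) + 1 if index else 0
    out = [0.0] * dim_size   # segments that receive no element keep the value 0
    argmax = [n] * dim_size  # ...and keep the sentinel index len(src)
    for i, (value, seg) in enumerate(zip(src, index)):
        # The first element of a segment always wins; later ones only if strictly larger.
        if argmax[seg] == n or value > out[seg]:
            out[seg] = value
            argmax[seg] = i
    return out, argmax


out, argmax = scatter_max_reference([2.0, 0.0, 1.0, 4.0, 3.0], [4, 5, 4, 2, 3])
print(out)     # [0.0, 0.0, 4.0, 3.0, 2.0, 0.0]
print(argmax)  # [5, 5, 3, 4, 0, 1]
```

Segments 0 and 1 receive no element, so they get the value 0 and the index 5, which is exactly the size of src.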
This is my temporary fix:
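```python
import torch_scatter

max_z, argmax_z = torch_scatter.scatter_max(in_bbox_particles_L[-2, :, 2], indices)
# argmax_z == size of src (here indices.shape[0]) flags the suspicious entries, whose value is 0
bug_mask = argmax_z == indices.shape[0]
max_z = max_z[~bug_mask]
argmax_z = argmax_z[~bug_mask]
```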
I checked with torch.unique whether the size of src appears as a value in indices, and it does not, so this is very likely a problem in scatter_max.
Beyond this, scatter_min has the same problem.
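A torch.unique check on indices cannot find that value, because the size of src never has to appear in indices: the sentinel is written into the output for every segment id in `0 .. max(indices)` that no element of src maps to. A minimal sketch for listing those empty segment ids up front; `empty_segments` is a hypothetical helper, and the default `dim_size` of `max(index) + 1` is an assumption mirroring scatter_max's default.

```python
from typing import List, Optional


def empty_segments(index: List[int], dim_size: Optional[int] = None) -> List[int]:
    """Hypothetical helper: segment ids that receive no element from src."""
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0
    present = set(index)
    # Any id in [0, dim_size) that never occurs in index is an empty segment.
    return [seg for seg in range(dim_size) if seg not in present]


print(empty_segments([4, 5, 4, 2, 3]))  # [0, 1]
```

In PyTorch, the same boolean mask over segments can likely be obtained with `torch.bincount(indices, minlength=dim_size) == 0`, assuming indices is a 1-D tensor of non-negative integers.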
Do you mean that argmax is filled with an invalid index when a segment is empty? This is working as designed, and your solution is the correct way to handle this downstream.
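When filtering, mask on the index sentinel rather than on the value 0, since a non-empty segment can legitimately have a maximum of 0. Dropping entries, as the temporary fix does, also shifts positions, so it can help to keep the surviving segment ids. A pure-Python sketch; `drop_empty_segments` is a hypothetical helper, and it assumes the sentinel equals the size of src, which the thread reports for both scatter_max and scatter_min.

```python
from typing import List, Tuple


def drop_empty_segments(
    out: List[float], argmax: List[int], src_size: int
) -> Tuple[List[int], List[float], List[int]]:
    """Hypothetical helper: drop empty segments, keeping each survivor's segment id."""
    # An argmax equal to src_size marks a segment that received no element.
    keep = [seg for seg, a in enumerate(argmax) if a != src_size]
    return keep, [out[seg] for seg in keep], [argmax[seg] for seg in keep]


# Segment 5 has a real maximum of 0.0 and must survive; segments 0 and 1 are empty.
seg_ids, values, idx = drop_empty_segments(
    [0.0, 0.0, 4.0, 3.0, 2.0, 0.0], [5, 5, 3, 4, 0, 1], src_size=5
)
print(seg_ids)  # [2, 3, 4, 5]
```

With tensors, the equivalent is roughly `mask = argmax != src.size(0)` followed by `seg_ids = mask.nonzero().squeeze(1)`.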
Cool, thanks. I didn't know it was designed that way. 👍