rusty1s / pytorch_scatter

PyTorch Extension Library of Optimized Scatter Operations

Home Page: https://pytorch-scatter.readthedocs.io

scatter or scatter_min fails when using torch.compile

gardiens opened this issue · comments

Hello,

I can't compile any model that includes scatter or scatter_min from torch_scatter.
For example, in this beautiful script:

import torch
import torch_geometric
from torch_scatter import scatter_min

print("the version of torch", torch.__version__)
print("torch_geometric version", torch_geometric.__version__)


def get_x(n_points=100):
    # Sample n_points random points inside the box [0, 10]^3.
    x_min = [0, 10]
    y_min = [0, 10]
    z_min = [0, 10]

    x = torch.rand((n_points, 3))
    x[:, 0] = x[:, 0] * (x_min[1] - x_min[0]) + x_min[0]
    x[:, 1] = x[:, 1] * (y_min[1] - y_min[0]) + y_min[0]
    x[:, 2] = x[:, 2] * (z_min[1] - z_min[0]) + z_min[0]

    return x


device = "cuda"
x = get_x(n_points=10).to(device)
se = torch.randint(low=0, high=10, size=(10,), device=device)

model = scatter_min
compiled_model = torch.compile(model)

# scatter_min returns a (values, argmin) tuple.
expected, expected_argmin = model(x, se, dim=0)
out, argmin = compiled_model(x, se, dim=0)
assert torch.allclose(out, expected, atol=1e-6)

The code fails with:

 torch._dynamo.exc.TorchRuntimeError: Failed running call_function torch_scatter.scatter_min(*(FakeTensor(..., size=(10, 3)), FakeTensor(..., size=(10,), dtype=torch.int64), 0, None, None), **{}):
The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

from user code:
 line 65, in scatter_min
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)

My versions are torch 2.2.0, torch_geometric 2.5.2, and torch_scatter 2.1.2.

This is currently expected, since the custom ops provided by torch-scatter are not supported in torch.compile (they have no fake-tensor/meta implementations, which is why Dynamo fails while tracing them).

For this, we added torch_geometric.utils.scatter, which will pick the best computation path depending on your input arguments, and which also works with torch.compile.
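A minimal sketch of that suggestion, assuming torch_geometric.utils.scatter with reduce="min" (note that, unlike torch_scatter.scatter_min, utils.scatter returns only the reduced tensor, not the argmin indices):

import torch
from torch_geometric.utils import scatter

x = torch.rand((10, 3))
index = torch.randint(low=0, high=10, size=(10,))

# utils.scatter is built from plain PyTorch ops, so Dynamo can trace it
# end to end instead of hitting an untraceable custom op.
compiled_scatter = torch.compile(scatter)

expected = scatter(x, index, dim=0, reduce="min")
out = compiled_scatter(x, index, dim=0, reduce="min")
assert torch.allclose(out, expected, atol=1e-6)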

If I understand correctly, you suggest that instead of directly calling scatter_min or scatter_max from torch_scatter, we should use utils.scatter by default?

Yes, if you want torch.compile support, then this is the recommended way.
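If you do need the argmin output of torch_scatter.scatter_min (which utils.scatter does not return), one possible workaround, a sketch rather than an approach endorsed in this thread, is to exclude the custom op from compilation with torch._dynamo.disable, so torch.compile inserts a graph break and runs that call eagerly:

import torch
import torch._dynamo
from torch_scatter import scatter_min

# Assumption: torch._dynamo.disable keeps Dynamo from tracing into this
# function; the surrounding compiled code gets a graph break here and the
# custom op runs eagerly on real tensors.
@torch._dynamo.disable
def scatter_min_eager(src, index, dim=0):
    return scatter_min(src, index, dim=dim)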