scatter or scatter_min fails when using torch.compile
gardiens opened this issue · comments
Hello,
I can't compile any model that includes scatter or scatter_min from torch_scatter.
For example, in this minimal script:
```python
import torch
import torch_geometric
from torch_scatter import scatter_min

print("torch version:", torch.__version__)
print("torch_geometric version:", torch_geometric.__version__)

def get_x(n_points=100):
    # Random points in the box [0, 10]^3
    x_range = [0, 10]
    y_range = [0, 10]
    z_range = [0, 10]
    x = torch.rand((n_points, 3))
    x[:, 0] = x[:, 0] * (x_range[1] - x_range[0]) + x_range[0]
    x[:, 1] = x[:, 1] * (y_range[1] - y_range[0]) + y_range[0]
    x[:, 2] = x[:, 2] * (z_range[1] - z_range[0]) + z_range[0]
    return x

x = get_x(n_points=10)
se = torch.randint(low=0, high=10, size=(10,))

model = scatter_min
compiled_model = torch.compile(model)

# scatter_min returns a (values, argmin) tuple
expected, expected_argmin = model(x, se, dim=0)
out, out_argmin = compiled_model(x, se, dim=0)  # fails here
assert torch.allclose(out, expected, atol=1e-6)
```
The code fails with:

```
torch._dynamo.exc.TorchRuntimeError: Failed running call_function torch_scatter.scatter_min(*(FakeTensor(..., size=(10, 3)), FakeTensor(..., size=(10,), dtype=torch.int64), 0, None, None), **{}):
The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

from user code:
  line 65, in scatter_min
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)
```
My torch version is 2.2.0, torch_geometric is 2.5.2, and torch_scatter is 2.1.2.
This is currently expected, since the custom ops in torch-scatter are not supported by torch.compile. There are two options:

- Disallow the use of compile for certain ops.
- Fall back to torch.Tensor.scatter_reduce_ (https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html). This is also what we are doing on the PyG side. For this, we added utils.scatter, which will pick the best computation path depending on your input arguments. It also works with torch.compile.
If I understand correctly, you suggest that instead of directly calling scatter_min or scatter_max from torch_scatter, we should use utils.scatter by default?
Yes, if you want torch.compile support, then this is the recommended way.