Overflow encountered at GCL.augmentors.functional.random_walk_subgraph
Simplewyl2000 opened this issue · comments
Hi! I found that when I use A.RWSampling(), the augmentor sometimes crashes with:
python3.8/site-packages/torch_geometric/utils/subgraph.py", line 40, in subgraph
n_mask[subset] = 1
IndexError: index 4542161131129163139 is out of bounds for dimension 0 with size 1698
I then looked into the random_walk_subgraph function. The node_idx computed there is passed to subgraph as the subset argument, and it causes the out-of-bounds index: some entries of node_idx exceed the upper limit.
def random_walk_subgraph(edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, batch_size: int = 1000, length: int = 10):
    num_nodes = edge_index.max().item() + 1
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes))
    start = torch.randint(0, num_nodes, size=(batch_size, ), dtype=torch.long).to(edge_index.device)
    node_idx = adj.random_walk(start.flatten(), length).view(-1)
    edge_index, edge_weight = subgraph(node_idx, edge_index, edge_weight)
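To see where the bad values enter, one option is to inspect the walks before they are flattened with view(-1). The following is only a minimal sketch: it assumes the walk result is a 2-D tensor with one row per start node, and the helper name find_invalid_steps and the toy values are hypothetical, not part of GCL or torch_sparse.

```python
import torch


def find_invalid_steps(walks: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Return (walk, step) index pairs whose node id is outside [0, num_nodes)."""
    bad = (walks < 0) | (walks >= num_nodes)
    # nonzero() on a 2-D boolean tensor yields one [row, col] pair per True entry
    return bad.nonzero()


# Toy walks (hypothetical data): three walks of three steps each
walks = torch.tensor([
    [0, 1, 2],
    [5, 4542161131129163139, 7],
    [3, 3, -1],
])
bad_positions = find_invalid_steps(walks, num_nodes=1698)
print(bad_positions.tolist())
```

Knowing which walk and which step first goes out of range may help narrow down whether the start nodes or a later step of the walk is responsible.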
However, I could not work out how adj.random_walk generates the invalid node_idx, so I added a check myself:
node_idx = adj.random_walk(start.flatten(), length).view(-1)
for index, value in enumerate(node_idx):
    if value >= edge_index.max().item() + 1 or value < 0:
        node_idx[index] = random.randint(0, edge_index.max().item())
edge_index, edge_weight = subgraph(node_idx, edge_index, edge_weight)
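The same range check can be written without the Python loop. This is a minimal sketch of a workaround, not a fix for the underlying cause: it drops out-of-range entries instead of replacing them with random nodes, and the helper name filter_valid_nodes and the sample values are hypothetical.

```python
import torch


def filter_valid_nodes(node_idx: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Keep only indices inside [0, num_nodes), deduplicated and sorted."""
    mask = (node_idx >= 0) & (node_idx < num_nodes)
    # torch.unique returns the remaining indices in ascending order
    return torch.unique(node_idx[mask])


# Sample values (hypothetical), including the index from the traceback
node_idx = torch.tensor([3, 4542161131129163139, 0, -1, 3, 1697])
valid = filter_valid_nodes(node_idx, num_nodes=1698)
print(valid.tolist())
```

Dropping the invalid entries avoids adding nodes that the walk never visited, at the cost of a slightly smaller sampled subgraph.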
But this does not seem like a proper solution. Could you tell me why this problem happens, or whether there is anything wrong with adj.random_walk?
Thanks!
torch 1.10.0+cu113
torch-cluster 1.6.1
torch-geometric 1.7.0
torch-scatter 2.1.1
torch-sparse 0.6.13
torch-spline-conv 1.2.2
cuda11.3