PyGCL: A PyTorch Library for Graph Contrastive Learning

Home Page:https://PyGCL.readthedocs.io

Overflow encountered at GCL.augmentors.functional.random_walk_subgraph

Simplewyl2000 opened this issue

Hi! When I use A.RWSampling(), the augmentor sometimes crashes with:

python3.8/site-packages/torch_geometric/utils/subgraph.py", line 40, in subgraph
    n_mask[subset] = 1
IndexError: index 4542161131129163139 is out of bounds for dimension 0 with size 1698

I then dug into the random_walk_subgraph function. The node_idx computed there is passed to subgraph as its subset argument, and some of its entries are above the upper limit (the number of nodes), which causes the out-of-bounds index:

def random_walk_subgraph(edge_index: torch.LongTensor, edge_weight: Optional[torch.FloatTensor] = None, batch_size: int = 1000, length: int = 10):
    num_nodes = edge_index.max().item() + 1

    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes))

    start = torch.randint(0, num_nodes, size=(batch_size, ), dtype=torch.long).to(edge_index.device)
    node_idx = adj.random_walk(start.flatten(), length).view(-1)
    edge_index, edge_weight = subgraph(node_idx, edge_index, edge_weight)

However, I could not work out how adj.random_walk generates the invalid node_idx values, so I added a check myself:

    node_idx = adj.random_walk(start.flatten(), length).view(-1)
    for index, value in enumerate(node_idx):
        if value >= edge_index.max().item() + 1 or value < 0:
            node_idx[index] = random.randint(0, edge_index.max().item())
    edge_index, edge_weight = subgraph(node_idx, edge_index, edge_weight)

But this does not seem like a proper solution. Could you tell me why this problem happens, or whether there is anything wrong with adj.random_walk?
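
For reference, the same range check can be written without the Python loop. The sketch below is only an illustration and assumes plain torch: the helper name filter_valid_nodes is hypothetical and is not part of PyGCL, torch_sparse, or torch_geometric. It drops out-of-range entries instead of replacing them with random nodes, and like the loop above it only hides the symptom; it does not explain why adj.random_walk returns such values.

```python
import torch


def filter_valid_nodes(node_idx: torch.Tensor, num_nodes: int):
    """Hypothetical helper: keep only indices that lie in [0, num_nodes)."""
    # Boolean mask of entries that are legal node indices.
    valid = (node_idx >= 0) & (node_idx < num_nodes)
    # Count the rejected entries so the problem stays visible in logs.
    num_invalid = int((~valid).sum())
    # Deduplicate the surviving indices (unique() also sorts them).
    return node_idx[valid].unique(), num_invalid


# Toy input that includes the out-of-range value from the traceback above.
walk = torch.tensor([0, 5, 4542161131129163139, -1, 5], dtype=torch.long)
kept, num_invalid = filter_valid_nodes(walk, num_nodes=1698)
print(kept.tolist(), num_invalid)
```

Inside random_walk_subgraph, the filtered tensor would presumably be passed to subgraph in place of node_idx, with num_invalid logged so that a nonzero count can still be investigated.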

Thanks!

torch 1.10.0+cu113
torch-cluster 1.6.1
torch-geometric 1.7.0
torch-scatter 2.1.1
torch-sparse 0.6.13
torch-spline-conv 1.2.2
CUDA 11.3

@SXKDZ @Linyxus @dongkwan-kim @zlpure @AzureLeon1