Graph-COM / GSAT

[ICML 2022] Graph Stochastic Attention (GSAT) for interpretable and generalizable graph learning.

Home page: https://arxiv.org/abs/2201.12987

directed edge weights on undirected graphs

simoons95 opened this issue.

Hello,

First of all, thank you for your paper and your code; it is a pleasure to work with.

However, I have a question about the following line:

edge_att = (att + transpose(data.edge_index, att, nodesize, nodesize, coalesced=False)[1]) / 2

When I run it, this line seems to do nothing more than edge_att = (att + att) / 2.
As a result, the weight of an edge depends on its direction (0 to 1 != 1 to 0).
Have I missed anything?
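A minimal plain-Python sketch of why the averaging is a no-op, assuming (as the torch_sparse source does) that `transpose(..., coalesced=False)` only swaps the two rows of `edge_index` and returns the values in their original order. `transpose_uncoalesced` is a hypothetical stand-in written for illustration, not a library function.

```python
def transpose_uncoalesced(edge_index, values):
    # Hypothetical stand-in for transpose(..., coalesced=False):
    # swap source and target rows, leave the value order untouched.
    row, col = edge_index
    return [list(col), list(row)], list(values)


edge_index = [[0, 1], [1, 0]]  # directed edges 0->1 and 1->0
att = [0.75, 0.25]             # one attention value per directed edge

# Mirrors transpose(...)[1] in the quoted line: only the values are used.
trans_val = transpose_uncoalesced(edge_index, att)[1]
edge_att = [(a + t) / 2 for a, t in zip(att, trans_val)]
print(edge_att)  # [0.75, 0.25]: 0->1 and 1->0 still get different weights
```

Because position i of `trans_val` still belongs to edge i rather than to its reverse, each entry is simply averaged with itself.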

Thanks a lot for spotting and reporting this issue. After a quick check of the code, I found that this is indeed a bug introduced by PR #3. It is caused by edge_index tensors that are not properly sorted, and I will fix it soon. Thank you very much!

Can it really happen that the input indices are not sorted (the reason for PR #3), given that they come from a dataloader?

If I remember correctly, I opened PR #3 because I found that data.edge_index is not sorted for the mutag dataset even though it comes from a dataloader. For example, it gives something like [[0, 1, 0, 2, 0], [1, 0, 2, 0, 3]], whereas the source row would be [0, 0, 0, 1, 2] if it were sorted.
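To make the mutag example concrete, here is a plain-Python sketch of sorting an edge list by (source, target), which is the order `torch_geometric.utils.sort_edge_index` produces with its default `sort_by_row=True`. `sort_edges` is a hypothetical helper written for illustration only.

```python
def sort_edges(edge_index):
    # Sort edges lexicographically by (source, target), then split back into two rows.
    pairs = sorted(zip(edge_index[0], edge_index[1]))
    return [[s for s, _ in pairs], [t for _, t in pairs]]


unsorted = [[0, 1, 0, 2, 0], [1, 0, 2, 0, 3]]
print(sort_edges(unsorted))  # [[0, 0, 0, 1, 2], [1, 2, 3, 0, 0]]
```

Code that pairs values from a sorted (coalesced) tensor with an unsorted `data.edge_index` position by position will therefore match the wrong edges.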

I just fixed this issue in PR #7, and the code should now work properly. Thanks again for spotting this issue, and feel free to let us know if you encounter any more issues!

I see that you move to the CPU but never move back to the GPU, which may lead to some bugs.
Maybe the following function could help:

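from torch_geometric.utils import sort_edge_index

def reorder_like(from_edge_index, to_edge_index, values):
    # Sort the source edges (and their values) lexicographically by (row, col).
    from_edge_index, values = sort_edge_index(from_edge_index, values)
    # Encode each target edge (row, col) as one integer with the same (row, col) order.
    ranking_score = to_edge_index[0] * (to_edge_index.max() + 1) + to_edge_index[1]
    # argsort().argsort() gives the rank of every target edge in that order.
    ranking = ranking_score.argsort().argsort()
    if not (from_edge_index[:, ranking] == to_edge_index).all():
        raise ValueError("Edges in from_edge_index and to_edge_index are different, impossible to match both.")
    # Values are now aligned with to_edge_index, position by position.
    return values[ranking]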

You can use it like this:

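from torch_sparse import transpose

trans_idx, trans_val = transpose(data.edge_index, att, None, None, coalesced=False)
# Reorder so that position i holds the attention of the reverse of edge i.
trans_val_perm = reorder_like(trans_idx, data.edge_index, trans_val)
edge_att = (att + trans_val_perm) / 2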
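For readers without PyTorch at hand, here is a plain-Python sketch of the same alignment, assuming lists instead of tensors. `argsort` and `reorder_like_py` are hypothetical helpers, and tuple keys stand in for the integer `ranking_score` encoding; both give the same lexicographic (row, col) order. The double `argsort` turns the sort order of the target edges into a per-edge rank, so indexing the sorted source edges with it reproduces the target order.

```python
def argsort(xs):
    # Indices that would sort xs (plain-Python stand-in for Tensor.argsort).
    return sorted(range(len(xs)), key=lambda i: xs[i])


def reorder_like_py(from_edge_index, to_edge_index, values):
    # Hypothetical list-based mirror of reorder_like; edges compared as (row, col) tuples.
    from_edges = list(zip(from_edge_index[0], from_edge_index[1]))
    to_edges = list(zip(to_edge_index[0], to_edge_index[1]))
    if len(from_edges) != len(to_edges):
        raise ValueError("Edge lists have different lengths.")
    # Sort the source edges and their values by (row, col).
    order = argsort(from_edges)
    from_sorted = [from_edges[i] for i in order]
    values_sorted = [values[i] for i in order]
    # argsort of argsort: ranking[i] is the rank of to_edges[i] in (row, col) order.
    ranking = argsort(argsort(to_edges))
    if [from_sorted[r] for r in ranking] != to_edges:
        raise ValueError("Edges in from_edge_index and to_edge_index are different.")
    return [values_sorted[r] for r in ranking]


# Unsorted, symmetric toy graph: edges 0->1, 1->0, 0->2, 2->0.
edge_index = [[0, 1, 0, 2], [1, 0, 2, 0]]
att = [0.75, 0.25, 0.5, 0.125]

# transpose(..., coalesced=False) swaps row and col but keeps the value order.
trans_idx = [edge_index[1], edge_index[0]]
trans_val_perm = reorder_like_py(trans_idx, edge_index, att)
edge_att = [(a + t) / 2 for a, t in zip(att, trans_val_perm)]
print(edge_att)  # [0.5, 0.5, 0.3125, 0.3125]
```

Each edge and its reverse now receive the same weight, which is the symmetric behavior the original line was meant to produce.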

Thanks a lot for the suggestion! I created a new PR #8 and updated the code as you suggested. :)