graphdeeplearning / graphtransformer

Graph Transformer Architecture. Source code for "A Generalization of Transformer Networks to Graphs", DLG-AAAI'21.

Home Page: https://arxiv.org/abs/2012.09699

Detail on softmax

DevinKreuzer opened this issue · comments

Great work!

I have a question concerning the implementation of softmax in graph_transformer_edge_layer.py.

When you define the softmax, you use the following function:

def exp(field):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))}
    return func

Shouldn't the attention weights/scores be scalars? From what I see, each head has an 8-dimensional score vector, on which you then call .sum(). The corresponding layer in graph_transformer_layer.py does not have this .sum():

def scaled_exp(field, scale_constant):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: torch.exp((edges.data[field] / scale_constant).clamp(-5, 5))}

    return func

Would appreciate any clarification on this :)

Best,
Devin

Hi @DevinKreuzer,

The .sum() is done here in graph_transformer_layer.py:

def func(edges):
    return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}
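
For intuition, a minimal shape check (illustrative only; the tensor names and dimensions below are made up, not taken from the repo) showing that this dot product already yields one scalar score per edge and per head:

import torch

E, H, D = 4, 8, 8                      # edges, heads, per-head feature dim (illustrative)
K_src = torch.randn(E, H, D)           # K projected onto edge source nodes
Q_dst = torch.randn(E, H, D)           # Q projected onto edge destination nodes

# elementwise product followed by .sum(-1) reduces over the feature dimension
score = (K_src * Q_dst).sum(-1, keepdim=True)
print(score.shape)                     # torch.Size([4, 8, 1]) -> one scalar per edge and head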

@DevinKreuzer: Shouldn't the attention weights/scores be scalars? From what I see, each head has an 8-dimensional score vector

  • In graph_transformer_edge_layer.py, the injection of available edge features is done feature-dimension-wise, i.e. the implicit attention scores (per feature dimension) are multiplied with the available edge features (per feature dimension), as in Eqn. 12 of the paper, implemented as:

    def func(edges):
        return {implicit_attn: (edges.data[implicit_attn] * edges.data[explicit_edge])}

  • Eqn. 12 outputs a d-dimensional feature vector (where d is the feature dimension). This d-dimensional edge feature vector is critical since it's passed to the edge feature pipeline (maintained at every layer), starting from Eqn. 10 and leading towards Eqns. 16-18 in the paper. In Eqn. 11, the features of \hat{w}_{i, j} are summed across the d dimensions to obtain scalars, which is the .sum() you mention in your query (see the shape sketch after the code below).

def exp(field):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-5, 5))}
    return func
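
To make the shape bookkeeping concrete, here is a minimal sketch (not from the repo; tensor names and sizes are made up for illustration) of how the edge layer keeps the per-dimension scores and only reduces them to scalars inside exp():

import torch

E, H, D = 4, 8, 8                      # edges, heads, per-head feature dim (illustrative)
K_src = torch.randn(E, H, D)           # K projected onto edge source nodes
Q_dst = torch.randn(E, H, D)           # Q projected onto edge destination nodes
E_feat = torch.randn(E, H, D)          # projected edge features

# per-dimension scores multiplied with edge features (Eqn. 12); stays d-dimensional
w_hat = (K_src * Q_dst) * E_feat
print(w_hat.shape)                     # torch.Size([4, 8, 8]) -> kept for the edge pipeline

# only inside exp() is it reduced to a scalar per head (Eqn. 11), then clamped and exponentiated
score = torch.exp(w_hat.sum(-1, keepdim=True).clamp(-5, 5))
print(score.shape)                     # torch.Size([4, 8, 1]) -> one scalar per edge and head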

Hope this helps with understanding the implementation.
Vijay

Closing the issue for now. Feel free to open for any (further) clarification.