CRIPAC-DIG / GCA

[WWW 2021] Source code for "Graph Contrastive Learning with Adaptive Augmentation"



Question about code

GeniusYx opened this issue

Sorry to bother you, but I'm confused about why edge_weights is divided by edge_weights.mean() here:

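```python
import torch

def drop_edge_weighted(edge_index, edge_weights, p: float, threshold: float = 1.):
    # Rescale the weights to average 1, then multiply by p to get drop probabilities
    edge_weights = edge_weights / edge_weights.mean() * p
    # Cap every drop probability at threshold
    edge_weights = edge_weights.where(edge_weights < threshold, torch.ones_like(edge_weights) * threshold)
    # Keep each edge with probability 1 - (its drop probability)
    sel_mask = torch.bernoulli(1. - edge_weights).to(torch.bool)
    return edge_index[:, sel_mask]
```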
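As far as the code itself shows, dividing by `edge_weights.mean()` rescales the weights so they average to 1, so the following `* p` makes the *average* per-edge drop probability equal to `p`, whatever the absolute scale of the incoming scores. The `where` call then caps each individual probability at `threshold`, which means the average after capping can fall below `p`. The sketch below is a hypothetical plain-Python mirror of that normalization, written without PyTorch so the arithmetic is easy to check; the names `drop_probabilities` and `drop_edges` are illustrative and are not part of the repository.

```python
import random


def drop_probabilities(edge_weights, p, threshold=1.0):
    """Hypothetical plain-Python mirror of the first two lines of drop_edge_weighted."""
    mean = sum(edge_weights) / len(edge_weights)
    # After dividing by the mean the weights average to 1,
    # so after multiplying by p the probabilities average to p (before capping).
    probs = [w / mean * p for w in edge_weights]
    # Same rule as edge_weights.where(edge_weights < threshold, threshold):
    # keep values below threshold, replace the rest with threshold.
    return [q if q < threshold else threshold for q in probs]


def drop_edges(edges, edge_weights, p, threshold=1.0, rng=random):
    """Keep each edge independently with probability 1 - drop probability."""
    probs = drop_probabilities(edge_weights, p, threshold)
    # rng.random() is uniform on [0, 1), so the test below passes with
    # probability 1 - q, like torch.bernoulli(1 - q).
    return [e for e, q in zip(edges, probs) if rng.random() >= q]


if __name__ == "__main__":
    weights = [1.0, 2.0, 3.0, 4.0]
    print(drop_probabilities(weights, p=0.3))  # ≈ [0.12, 0.24, 0.36, 0.48], mean ≈ 0.3
    # Scaling all weights by a constant leaves the probabilities unchanged.
    print(drop_probabilities([10 * w for w in weights], p=0.3))
    # One large weight gets capped, so the average drops below p = 0.5.
    print(drop_probabilities([1.0, 1.0, 1.0, 9.0], p=0.5, threshold=0.7))
```

In the last call the mean weight is 3, so the large edge would get probability 1.5 and is capped at 0.7. The three small edges get about 0.167 each, giving an average of 0.3 rather than 0.5. In other words, `p` behaves as a target average drop rate, and `threshold` keeps any single important edge from being removed with certainty.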