Graph Self-supervised Learning with Accurate Discrepancy Learning


The authors propose a framework called D-SLA that aims to learn the exact discrepancy between the original graph and its perturbed graphs.
Three major components:
Learn to distinguish whether each graph is the original graph or the perturbed one.
Capture the amount of discrepancy for each perturbed graph (using edit distance)
Learn relative discrepancy with other graphs


Graph Neural Networks (GNN)

Aggregate: each node aggregates the features from its neighbors
Update: combine the aggregated message with the node's own representation
Variants of Update & Aggregate functions
Graph Convolutional Network (GCN)
Generalizes the convolution operation to graphs + mean aggregation
GraphSAGE
Concatenates the neighbors' representations with the node's own representation when updating
Graph Attention Network (GAT)
Considers the relative importance among neighboring nodes when aggregating
Graph Isomorphism Network (GIN)
Sum aggregation
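The aggregate/update split above can be sketched in plain Python. This is a minimal sketch built on simplifying assumptions. Node features are plain lists, and the graph is an adjacency list. Learnable weights and nonlinearities, such as the MLP that GIN applies after summing, are omitted. The "mean" mode is a simplification: the actual GCN uses a degree-normalized sum with self-loops. The function names `aggregate`, `update_concat`, and `update_gin` are hypothetical and only illustrate how the aggregation and update variants differ.

```python
from typing import Dict, List

Vector = List[float]


def aggregate(features: Dict[int, Vector], neighbors: Dict[int, List[int]],
              node: int, mode: str = "sum") -> Vector:
    """Aggregate the neighbor features of `node` by sum (GIN-style) or mean (simplified GCN-style)."""
    message = [0.0] * len(features[node])
    for nb in neighbors[node]:
        message = [m + x for m, x in zip(message, features[nb])]
    if mode == "mean" and neighbors[node]:
        message = [m / len(neighbors[node]) for m in message]
    return message


def update_concat(own: Vector, message: Vector) -> Vector:
    """Update by concatenating the node's own representation with the aggregated message."""
    return own + message


def update_gin(own: Vector, message: Vector, eps: float = 0.0) -> Vector:
    """GIN-style combination (1 + eps) * own + message; the MLP applied afterwards is omitted."""
    return [(1.0 + eps) * o + m for o, m in zip(own, message)]


# Toy path graph 0 - 1 - 2 with 2-dimensional node features.
features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbors = {0: [1], 1: [0, 2], 2: [1]}

print(aggregate(features, neighbors, 1, "sum"))   # [2.0, 1.0]
print(aggregate(features, neighbors, 1, "mean"))  # [1.0, 0.5]
print(update_concat(features[1], aggregate(features, neighbors, 1, "mean")))  # [0.0, 1.0, 1.0, 0.5]
print(update_gin(features[1], aggregate(features, neighbors, 1, "sum")))      # [2.0, 2.0]
```

Sum aggregation keeps the neighborhood size visible in the message, and mean aggregation hides it. This difference is part of why sum aggregation (GIN) is regarded as more expressive.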

Self-supervised learning for graphs (GSL)

Aims to learn a good representation of the graphs in an unsupervised manner.
→ Transfer this knowledge to downstream tasks.
The two most prevalent frameworks for GSL
Predictive learning (PL)
Aims to learn contextual relationships by predicting sub-graphical features (nodes, edges, subgraphs)
predict the attributes of masked nodes
predict the presence of an edge or a path
predict the generative sequence, contextual property, and motifs
But predictive learning may not capture the global structures and/or semantics of graphs.
Contrastive learning (CL)
Aims to capture global level information.
Early CL methods learn the similarity between the entire graph and its substructures.
Others build augmented views via attribute masking, edge perturbation, and subgraph sampling.
Recent adversarial CL methods generate positive examples either by adaptively removing edges or by adjusting attributes.
But CL may fail to distinguish two graphs that are topologically similar yet have completely different properties.
Minimize $\mathcal{L}_{CL} = -\log \frac{f_{\text{sim}}(h_{\mathcal{G}_i}, h_{\mathcal{G}_j})}{\sum_{\mathcal{G}',\, \mathcal{G}' \neq \mathcal{G}_0} f_{\text{sim}}(h_{\mathcal{G}_i}, h_{\mathcal{G}'})}$
$\mathcal{G}_0$: original graph
$\mathcal{G}_i, \mathcal{G}_j$: perturbed graphs
$\mathcal{G}'$: another graph in the same batch as $\mathcal{G}_0$, a.k.a. a negative graph
$h_\mathcal{G}$: embedding of graph $\mathcal{G}$
Positive pair: $(\mathcal{G}_i, \mathcal{G}_j)$; negative pair: $(\mathcal{G}_i, \mathcal{G}')$
$f_\text{sim}$: similarity function between two graph embeddings, e.g., based on the $L_2$ distance or cosine similarity
→ similarity of the positive pair $\uparrow$, similarity of negative pairs $\downarrow$
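A minimal sketch of $\mathcal{L}_{CL}$ for a single positive pair follows. It assumes $f_\text{sim}(a, b) = \exp(\cos(a, b) / \tau)$, a common InfoNCE-style choice, with a temperature `tau` that the notes do not specify. The helper names `cosine`, `f_sim`, and `contrastive_loss` are hypothetical. As in the formula, the denominator sums over negative pairs only. Many implementations, such as NT-Xent, also include the positive pair in the denominator, and with that choice the loss is always non-negative.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def f_sim(a: Sequence[float], b: Sequence[float], tau: float = 0.5) -> float:
    """Assumed similarity exp(cosine / tau); being positive, it keeps the log well-defined."""
    return math.exp(cosine(a, b) / tau)


def contrastive_loss(h_i: Sequence[float], h_j: Sequence[float],
                     negatives: List[Sequence[float]], tau: float = 0.5) -> float:
    """L_CL for one positive pair (h_i, h_j) against the negative embeddings h_{G'}."""
    positive = f_sim(h_i, h_j, tau)
    negative = sum(f_sim(h_i, h_neg, tau) for h_neg in negatives)
    return -math.log(positive / negative)


# An aligned positive pair yields a lower loss than a misaligned one.
print(contrastive_loss([1.0, 0.0], [1.0, 0.0], [[0.0, 1.0]]))  # about -2.0
print(contrastive_loss([1.0, 0.0], [0.0, 1.0], [[1.0, 0.0]]))  # about 2.0
```

Because the positive pair is excluded from the denominator, the loss can become negative, as in the first call.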

Discrepancy Learning

Discriminate original vs perturbed
A perturbed graph could be semantically incorrect!
→ Embed the perturbed graphs apart from the original.
$\mathcal{L}_{GD} = -\log \Big( \frac{e^{S_0}}{e^{S_0} + \sum_{i \geq 1} e^{S_i}} \Big)$ with $S = f_S(h_\mathcal{G})$
$f_S$: score function (the discriminator) that maps a graph embedding to a scalar score
large value of $e^{S_0}$ for the original graph
small value of $e^{S_i}$ for the perturbed graphs
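$\mathcal{L}_{GD}$ is a softmax cross-entropy in which the original graph is the target class among itself and its perturbations. Below is a minimal sketch. It assumes the scores $S_0, S_1, \dots$ have already been computed by some score function $f_S$, which is not shown. The function name `discriminator_loss` is hypothetical. The implementation uses the log-sum-exp trick so that large scores do not overflow.

```python
import math
from typing import Sequence


def discriminator_loss(scores: Sequence[float]) -> float:
    """L_GD = -log(e^{S_0} / (e^{S_0} + sum_{i>=1} e^{S_i})).

    scores[0] is the score S_0 of the original graph; scores[1:] are the
    scores S_i of its perturbed graphs (assumed to come from f_S).
    """
    # Log-sum-exp with the maximum subtracted, for numerical stability.
    m = max(scores)
    log_denominator = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_denominator - scores[0]


print(discriminator_loss([0.0, 0.0, 0.0]))    # log(3): original is indistinguishable
print(discriminator_loss([5.0, -5.0, -5.0]))  # close to 0: original clearly scored highest
```

When every graph receives the same score, the loss is stuck at $\log(n)$, so minimizing it forces $f_S$ to score the original graph above its perturbations.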
How to perturb?
Aim for the perturbed graph to be semantically incorrect
Remove or add a small number of edges
Manipulate the edge set $\mathcal{E}$ by removing existing edges and adding new ones, with the edge attributes $\mathcal{X}_\mathcal{E}$ updated accordingly
Mask node attributes
Randomly mask the node attributes $\mathcal{X}_\mathcal{V}$ of both the original and perturbed graphs
(to make it more difficult to distinguish between them)
$\mathcal{G}_0 = (\mathcal{V}, \mathcal{E}, \tilde{\mathcal{X}}^0_{\mathcal{V}}, \mathcal{X}_\mathcal{E}), \quad \tilde{\mathcal{X}}^0_{\mathcal{V}} \sim \texttt{M}(\mathcal{G})$
$\mathcal{G}_i = (\mathcal{V}, \mathcal{E}^i, \tilde{\mathcal{X}}^i_{\mathcal{V}}, \mathcal{X}^i_\mathcal{E}), \quad \tilde{\mathcal{X}}^i_{\mathcal{V}} \sim \texttt{M}(\mathcal{G}), \quad (\mathcal{E}^i, \mathcal{X}^i_\mathcal{E}) \sim \texttt{P}(\mathcal{G})$
$\texttt{M}$: node attribute masking; $\texttt{P}$: edge perturbation
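The two perturbation operators can be sketched as follows, under several assumptions that are not from the notes. Edges are undirected, unattributed `(u, v)` tuples with `u < v`, so edge attributes are ignored. The numbers of removed and added edges are chosen by the caller. Masking replaces a node's whole attribute vector with a constant `mask_value`. `perturb_edges` and `mask_node_attributes` are hypothetical stand-ins for $\texttt{P}$ and $\texttt{M}$. The sketch also returns the number of edits, which serves as the edit distance $e_i$ used later.

```python
import random
from typing import List, Set, Tuple

Edge = Tuple[int, int]  # undirected edge stored as (u, v) with u < v


def perturb_edges(num_nodes: int, edges: Set[Edge], num_remove: int, num_add: int,
                  rng: random.Random) -> Tuple[Set[Edge], int]:
    """Hypothetical P(G): remove existing edges and add new ones.

    Returns the perturbed edge set and the number of edits e_i. New edges are
    drawn only from node pairs absent in the original graph, so an addition
    never undoes a removal and e_i equals the size of the symmetric difference.
    """
    perturbed = set(edges)
    removed = rng.sample(sorted(edges), min(num_remove, len(edges)))
    perturbed.difference_update(removed)
    candidates = [(u, v) for u in range(num_nodes) for v in range(u + 1, num_nodes)
                  if (u, v) not in edges]
    added = rng.sample(candidates, min(num_add, len(candidates)))
    perturbed.update(added)
    return perturbed, len(removed) + len(added)


def mask_node_attributes(features: List[List[float]], mask_ratio: float,
                         rng: random.Random, mask_value: float = 0.0) -> List[List[float]]:
    """Hypothetical M(G): overwrite the attributes of a random subset of nodes."""
    num_masked = int(mask_ratio * len(features))
    masked = set(rng.sample(range(len(features)), num_masked))
    return [[mask_value] * len(x) if i in masked else list(x)
            for i, x in enumerate(features)]


rng = random.Random(0)
edges = {(0, 1), (1, 2), (2, 3), (0, 3)}
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]

# Original graph G_0: original edges, masked node attributes.
x_0 = mask_node_attributes(features, 0.25, rng)
# Perturbed graph G_i: perturbed edges plus independently masked node attributes.
edges_1, e_1 = perturb_edges(4, edges, num_remove=1, num_add=1, rng=rng)
x_1 = mask_node_attributes(features, 0.25, rng)
print(sorted(edges_1), e_1)
```

Since the perturbations are generated rather than measured, $e_i$ comes for free. This is what makes the exact discrepancy available without solving an NP-hard graph edit distance problem.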
Personal opinion
The real purpose of the discriminator loss is probably to push the original and perturbed graphs apart while the edit distance loss is applied; the edit distance loss alone would admit a collapsed solution.
Discrepancy with Edit distance
How dissimilar?
Usually, this requires measuring a graph distance, such as the edit distance.
Edit distance: the number of node and edge insertions, deletions, and substitutions needed to transform one graph into another. → Computing it is NP-hard!
But we know the exact number of perturbations applied to each graph
→ use it as distance.
$\mathcal{L}_{edit} = \sum_{i, j} \Big( \frac{d_i}{e_i} - \frac{d_j}{e_j} \Big)^2$ with $d_i = f_\text{diff}(h_{\mathcal{G}_0}, h_{\mathcal{G}_i})$
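$f_{\text{diff}}$ measures the embedding-level difference between two graphs using the $L_2$ norm.
$e_i$: edit distance (number of perturbations applied to $\mathcal{G}_i$)
→ The embedding distance $d_i$ should be proportional to the edit distance $e_i$, i.e., $d_i / e_i$ should be the same for all perturbed graphs.
The edit distance loss has a trivial solution $d_i = d_j = 0$, but the discriminator loss rules it out because it requires the original and perturbed graphs to be distinguishable.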
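A minimal sketch of $\mathcal{L}_{edit}$ for one original graph and its perturbations follows. It takes $f_\text{diff}$ to be the Euclidean norm of the embedding difference and assumes the embeddings and edit counts are given. The function names are hypothetical. The sum runs over unordered pairs $i < j$. This differs from summing over all ordered $(i, j)$ only by a constant factor of 2, because the $i = j$ terms vanish.

```python
import math
from itertools import combinations
from typing import List, Sequence


def f_diff(h_a: Sequence[float], h_b: Sequence[float]) -> float:
    """Embedding-level difference: L2 norm of h_a - h_b."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(h_a, h_b)))


def edit_distance_loss(h_original: Sequence[float],
                       h_perturbed: List[Sequence[float]],
                       edit_counts: List[int]) -> float:
    """L_edit over unordered pairs i < j: sum of (d_i / e_i - d_j / e_j)^2."""
    ratios = [f_diff(h_original, h) / e for h, e in zip(h_perturbed, edit_counts)]
    return sum((r_i - r_j) ** 2 for r_i, r_j in combinations(ratios, 2))


h_0 = [0.0, 0.0]
# d_1 = 5 with e_1 = 1 and d_2 = 10 with e_2 = 2: distances proportional to edits.
print(edit_distance_loss(h_0, [[3.0, 4.0], [6.0, 8.0]], [1, 2]))  # 0.0
# d_1 = d_2 = 5 although e_2 = 2: the ratios 5.0 and 2.5 disagree.
print(edit_distance_loss(h_0, [[3.0, 4.0], [3.0, 4.0]], [1, 2]))  # 6.25
```

Setting every embedding equal to `h_0` would also give zero loss, which is the trivial solution that the discriminator loss rules out.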
Relative discrepancy learning with other graphs
The distance between the original graph and the negative graphs in the same batch should be larger than the distance between the original graph and its perturbed graphs by at least a margin $\alpha$.
$\mathcal{L}_{margin} = \sum_{i, j} \max(0, \alpha + d_i - d'_j)$
$d_i$: distance between the original graph and its $i$-th perturbed graph
$d'_j$: distance between the original graph and the $j$-th negative graph
Intuitively, we want $\alpha + d_i < d'_j$!
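The margin loss is a hinge over all (perturbed, negative) pairs. Below is a minimal sketch. It assumes the distances $d_i$ and $d'_j$ were already computed with $f_\text{diff}$, and `margin_loss` is a hypothetical name. A pair contributes only when it violates the margin, that is, when $\alpha + d_i > d'_j$.

```python
from typing import Sequence


def margin_loss(d_pos: Sequence[float], d_neg: Sequence[float], alpha: float) -> float:
    """L_margin = sum over (i, j) of max(0, alpha + d_i - d'_j).

    d_pos[i] is d_i (original vs. i-th perturbed graph); d_neg[j] is d'_j
    (original vs. j-th negative graph in the batch). Pairs that already
    satisfy alpha + d_i <= d'_j contribute zero.
    """
    return sum(max(0.0, alpha + d_i - d_j) for d_i in d_pos for d_j in d_neg)


print(margin_loss([1.0, 2.0], [2.5], alpha=1.0))  # 0.5: only d_2 = 2.0 violates the margin
```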

Overall loss

$\mathcal{L} = \mathcal{L}_{GD} + \lambda_1 \mathcal{L}_{edit} + \lambda_2 \mathcal{L}_{margin}$
$\lambda_1, \lambda_2$: hyperparameters weighting the edit distance and margin losses