Graphs - Aug. and Training

Summary of GNN lectures

GNN Augmentation and Training#

GNN Training Pipeline#

The pipeline: input graph → GNN → node embeddings → prediction head → predictions, which are compared against labels with a loss function and evaluation metrics.

(e.g.: Latin exempli gratia, "for example")

Prediction Head#

Different task levels require different prediction heads

Node-level prediction#

  • After GNN computation, we have $d$-dim node embeddings: $\{{\rm h}_v^{(L)} \in {\Bbb R}^d,\forall v\in G\}$
  • Suppose we want to make a $k$-way prediction (classification or regression)
  • $\hat{y}_v = \text{Head}_\text{node}({\rm h}_v^{(L)}) = {\rm W}^{(H)}{\rm h}_{v}^{(L)}$; $({\rm W}^{(H)} \in {\Bbb R}^{k\times d},\ {\rm h}_{v}^{(L)} \in {\Bbb R}^{d},\ \hat{y}_v \in {\Bbb R}^{k})$
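
A minimal PyTorch sketch of such a node-level head (class and variable names are illustrative, not from any particular library):

```python
import torch
import torch.nn as nn

class NodeHead(nn.Module):
    """Maps final-layer node embeddings h_v^(L) (d-dim) to k-way predictions."""
    def __init__(self, d: int, k: int):
        super().__init__()
        self.linear = nn.Linear(d, k)  # plays the role of W^(H)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [num_nodes, d] -> logits: [num_nodes, k]
        return self.linear(h)

# usage with made-up sizes: 100 nodes, d = 64, k = 7
h = torch.randn(100, 64)
y_hat = NodeHead(d=64, k=7)(h)  # [100, 7]
```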

Edge-level prediction#

  • $\hat{y}_{uv} = \text{Head}_\text{edge}({\rm h}_u^{(L)}, {\rm h}_v^{(L)})$

    1. $\text{Head}_\text{edge}({\rm h}_u^{(L)}, {\rm h}_v^{(L)}) = \text{\small Linear}(\text{\small Concat}({\rm h}_u^{(L)}, {\rm h}_v^{(L)}))$; $(\hat{y}_{uv} \in {\Bbb R}^{k})$

    2. $\text{Head}_\text{edge}({\rm h}_u^{(L)}, {\rm h}_v^{(L)}) = ({\rm h}_u^{(L)})^\top{\rm h}_v^{(L)}$; $((1,d)\times(d,1),\ \hat{y}_{uv} \in {\Bbb R}^{1})$
      (1-way prediction, e.g., predicting the existence of an edge)

  • Applying to $k$-way prediction, with one learnable ${\rm W}^{(i)} \in {\Bbb R}^{d\times d}$ per output dimension (a code sketch follows below):

    $$
    \begin{aligned}
    \hat{y}_{uv}^{(1)} &= ({\rm h}_u^{(L)})^\top{\rm W}^{(1)}{\rm h}_v^{(L)}\\
    &\;\;\vdots\\
    \hat{y}_{uv}^{(k)} &= ({\rm h}_u^{(L)})^\top{\rm W}^{(k)}{\rm h}_v^{(L)}\\[4pt]
    \hat{y}_{uv} &= \text{\small Concat}(\hat{y}_{uv}^{(1)},\cdots,\hat{y}_{uv}^{(k)}) \in {\Bbb R}^k
    \end{aligned}
    $$
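
A rough sketch of both edge-level heads in PyTorch (class names are illustrative; `EdgeHeadBilinear` corresponds to the $k$-way dot-product variant above):

```python
import torch
import torch.nn as nn

class EdgeHeadConcat(nn.Module):
    """Option 1: concatenate the two endpoint embeddings, then apply a linear layer."""
    def __init__(self, d: int, k: int):
        super().__init__()
        self.linear = nn.Linear(2 * d, k)

    def forward(self, h_u: torch.Tensor, h_v: torch.Tensor) -> torch.Tensor:
        # h_u, h_v: [num_edges, d] -> [num_edges, k]
        return self.linear(torch.cat([h_u, h_v], dim=-1))

class EdgeHeadBilinear(nn.Module):
    """Option 2 (k-way): y_uv^(i) = h_u^T W^(i) h_v, one trainable W^(i) per output dim."""
    def __init__(self, d: int, k: int):
        super().__init__()
        self.W = nn.Parameter(torch.randn(k, d, d) * 0.01)

    def forward(self, h_u: torch.Tensor, h_v: torch.Tensor) -> torch.Tensor:
        # e: edges, i: output dims, a/b: embedding dims -> [num_edges, k]
        return torch.einsum("ea,iab,eb->ei", h_u, self.W, h_v)
```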

Graph-level prediction#

  • $\hat{y}_{G} = \text{Head}_\text{graph}(\{{\rm h}_v^{(L)} \in {\Bbb R}^d, \forall v \in G\})$

  • $\text{Head}_\text{graph}$ can be a global pooling function over all node embeddings, e.g. $\text{\small mean}$, $\text{\small max}$, or $\text{\small sum}$ (see the sketch after this list).

  • But! Simple global pooling over a (large) graph will lose information.

    • A solution: Let’s aggregate all the node embeddings hierarchically
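
A minimal sketch of global pooling for a single graph (the function name and signature are made up for illustration):

```python
import torch

def global_pool(h: torch.Tensor, how: str = "mean") -> torch.Tensor:
    """Pool all node embeddings of one graph into a single graph embedding.
    h: [num_nodes, d] -> [d]"""
    if how == "mean":
        return h.mean(dim=0)
    if how == "max":
        return h.max(dim=0).values
    if how == "sum":
        return h.sum(dim=0)
    raise ValueError(f"unknown pooling: {how}")
```

For intuition on the information loss: mean pooling cannot distinguish node embeddings $\{-1, 0, 1\}$ from $\{-100, 0, 100\}$, since both average to 0; hierarchical pooling (DiffPool, below) mitigates this on larger graphs.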

DiffPool#

Ying et al., Hierarchical Graph Representation Learning with Differentiable Pooling

  • Hierarchically pool node embeddings
  • Leverage 2 independent GNNs at each level
    • GNN A: Compute node embeddings
    • GNN B: Compute the cluster that a node belongs to
  • GNNs A and B at each level can be executed in parallel
  • For each Pooling layer
    • Use clustering assignments from GNN B to aggregate node embeddings generated by GNN A
    • Create a single new node for each cluster, maintaining edges between clusters to generate a new pooled network (see the sketch below)
  • Jointly train GNN A and GNN B
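
A sketch of one DiffPool level, following the paper's pooling equations ($S = \text{softmax}(\text{GNN}_B(A, X))$, $X' = S^\top Z$, $A' = S^\top A S$); here `gnn_a` and `gnn_b` stand in for any GNNs mapping an adjacency matrix and node features to node-level outputs:

```python
import torch
import torch.nn as nn

class DiffPoolLayer(nn.Module):
    """One DiffPool level (sketch); assumes dense adjacency A [n, n] and features X [n, d]."""
    def __init__(self, gnn_a: nn.Module, gnn_b: nn.Module):
        super().__init__()
        self.gnn_a = gnn_a   # GNN A: computes node embeddings
        self.gnn_b = gnn_b   # GNN B: computes cluster-assignment logits

    def forward(self, A: torch.Tensor, X: torch.Tensor):
        Z = self.gnn_a(A, X)                         # [n, d']  node embeddings
        S = torch.softmax(self.gnn_b(A, X), dim=-1)  # [n, c]   soft cluster assignments
        X_pooled = S.t() @ Z                         # [c, d']  one embedding per cluster
        A_pooled = S.t() @ A @ S                     # [c, c]   coarsened adjacency
        return A_pooled, X_pooled
```

The full method also adds auxiliary objectives (a link-prediction loss on $S$ and an entropy penalty on the assignments) to make the soft clustering easier to train.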

Training#

  • Supervised learning

    Labels come from external sources

    • Node labels $y_v$: e.g., in a citation network, the subject area a node belongs to
    • Edge labels $y_{uv}$: e.g., in a transaction network, whether an edge is fraudulent
    • Graph labels $y_G$: e.g., among molecular graphs, the drug-likeness of a graph

    Advice: Reduce your task to node / edge / graph labels, since they are easy to work with.
    e.g., if we know that some nodes form a cluster, we can treat the cluster a node belongs to as its node label.

  • Unsupervised learning (Self-supervised learning)

    We can find supervision signals within the graph itself.

    • Node-level $y_v$: e.g., clustering coefficient, PageRank, …
    • Edge-level $y_{uv}$: hide the edge between two nodes, then predict whether a link should exist there
    • Graph-level $y_G$: e.g., predict whether two graphs are isomorphic

Sometimes the differences are blurry

  • We still have “supervision” in unsupervised learning
    • e.g., train a GNN to predict the node clustering coefficient (a sketch follows below)
  • An alternative name for “unsupervised” is “self-supervised”
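
A small sketch of building such self-supervised node-level targets directly from graph structure (using NetworkX; the stacked targets and the karate-club example graph are just for illustration):

```python
import networkx as nx
import torch

G = nx.karate_club_graph()          # any graph works; this one ships with NetworkX
nodes = sorted(G.nodes())

clustering = nx.clustering(G)       # clustering coefficient per node
pagerank = nx.pagerank(G)           # PageRank score per node

# Stack into regression targets y_v for each node -- no external labels needed.
y = torch.tensor([[clustering[v], pagerank[v]] for v in nodes], dtype=torch.float)
# y: [num_nodes, 2]; train the GNN's node-level head against y with, e.g., an MSE loss.
```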

Data Splitting#

  1. Transductive setting: The entire graph can be observed in all dataset splits; we only split the labels.

    • At training time, we compute embeddings using the entire graph, and train using node 1&2’s labels
    • At validation time, we compute embeddings using the entire graph, and evaluate on node 3&4’s labels
    • Only applicable to node / edge prediction tasks
    • Training / validation / test sets are on the same graph
  2. Inductive setting: Each split can only observe the graph(s) within the split; we break the edges between splits to get multiple graphs. A successful model should generalize to unseen graphs.

    • At training time, we compute embeddings using the graph over node 1&2, and train using node 1&2’s labels
    • At validation time, we compute embeddings using the graph over node 3&4, and evaluate on node 3&4’s labels
    • Applicable to node / edge / graph tasks
    • Training / validation / test sets are on different graphs (a sketch contrasting the two settings follows below)
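
A minimal sketch of the two settings for node classification (all names, e.g. `model`, `A_full`, `X_full`, are hypothetical placeholders):

```python
import torch

num_nodes = 8

# Transductive: one graph, every split sees the full structure; only the labels are split.
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[[0, 1]] = True   # train on node 1 & 2's labels (0-indexed)
val_mask[[2, 3]] = True     # evaluate on node 3 & 4's labels

# out = model(A_full, X_full)                      # embeddings computed on the entire graph
# loss = criterion(out[train_mask], y[train_mask])
# val  = metric(out[val_mask], y[val_mask])

# Inductive: break edges between splits, so each split is an independent (sub)graph.
# loss = criterion(model(A_train, X_train), y_train)   # graph over nodes 1 & 2 only
# val  = metric(model(A_val, X_val), y_val)            # graph over nodes 3 & 4 only
```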

Data Splitting in Tasks#

  • Node Classification
    • Transductive node classification
    • Inductive node classification
  • Graph Classification
    • Only the inductive setting is well defined for graph classification
  • Link Prediction
    1. Assign 2 types of edges in the original graph
      • Message edges: Used for GNN message passing
      • Supervision edges: Used for computing objectives (not fed into the GNN)
    2. Split edges into train / validation / test
      1. Option 1: Inductive link prediction split

        • Each split contains its own independent graph(s); within each graph, edges are again divided into message edges and supervision edges

      2. Option 2: Transductive link prediction split

        • By definition of “transductive”, the entire graph can be observed in all dataset splits
        • But since edges are both part of the graph structure and the supervision signal, we need to hold out validation / test edges (see the sketch below)
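
A rough sketch of a transductive link-prediction split (the helper name and split fractions are made up; negative sampling and other details are omitted):

```python
import torch

def transductive_link_split(edge_index: torch.Tensor,
                            frac_train_msg: float = 0.6,
                            frac_train_sup: float = 0.2,
                            frac_val: float = 0.1):
    """edge_index: [2, num_edges]. Returns four disjoint edge sets; the remainder is test."""
    num_edges = edge_index.size(1)
    perm = torch.randperm(num_edges)
    n1 = int(frac_train_msg * num_edges)
    n2 = n1 + int(frac_train_sup * num_edges)
    n3 = n2 + int(frac_val * num_edges)
    train_msg = edge_index[:, perm[:n1]]     # message passing during training
    train_sup = edge_index[:, perm[n1:n2]]   # supervision edges for the training loss
    val_edges = edge_index[:, perm[n2:n3]]   # held out, scored at validation time
    test_edges = edge_index[:, perm[n3:]]    # held out, scored at test time
    return train_msg, train_sup, val_edges, test_edges

# At validation time, message passing can use train_msg + train_sup;
# at test time it can additionally use val_edges, so the observed graph grows across stages.
```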