Graph Neural Networks
| Article | |
|---|---|
| Topic area | deep-learning |
| Prerequisites | Neural Network, Graph Theory |
Overview
A graph neural network (GNN) is a class of Neural Network designed to operate on graph-structured data, where inputs consist of nodes connected by edges rather than fixed grids or sequences. Unlike convolutional or recurrent architectures, which exploit the regular topology of images or text, GNNs learn representations that respect arbitrary connectivity patterns and do not depend on any particular ordering of the nodes. They have become the standard tool for problems in which relational structure is informative, including molecular property prediction, social network analysis, recommendation systems, traffic forecasting, and knowledge graph reasoning. The unifying idea behind nearly all modern GNN variants is message passing: each node iteratively updates its representation by aggregating information from its neighbors, allowing local features to propagate across the graph and produce embeddings that capture both node attributes and structural context.[1]
Intuition and motivation
A graph $ G = (V, E) $ consists of a set of nodes $ V $, edges $ E \subseteq V \times V $, and optional features attached to each node and edge. Many real-world systems are naturally relational: atoms in a molecule, users in a social network, intersections in a road map, or entities in a Knowledge Graph. Applying a standard Multilayer Perceptron to such data requires flattening the graph into a fixed-size vector, which discards connectivity. Applying a Convolutional Neural Network requires a regular grid, which graphs lack. A GNN sidesteps both problems by computing node embeddings as a function of each node's local neighborhood, then composing layers so that each additional layer expands the receptive field by one hop.
Two properties make this work in practice. First, permutation equivariance: reordering the node indices reorders the outputs in the same way but does not change their values, which is the correct symmetry for unordered sets of neighbors. Second, locality: a node's update depends only on its immediate neighbors, so the same parameters can be reused across the entire graph regardless of size. Together these properties give GNNs a strong inductive bias for relational data, much as translation equivariance gives CNNs a strong bias for images.
Message passing formulation
The dominant abstraction is the message passing neural network (MPNN) framework.[2] Let $ h_v^{(l)} $ denote the embedding of node $ v $ at layer $ l $, and $ \mathcal{N}(v) $ its neighbors. A single layer applies two steps, aggregation followed by update:
$ {\displaystyle m_v^{(l+1)} = \mathrm{AGG}^{(l)}\big(\{h_u^{(l)} : u \in \mathcal{N}(v)\}\big)} $
$ {\displaystyle h_v^{(l+1)} = \mathrm{UPD}^{(l)}\big(h_v^{(l)}, m_v^{(l+1)}\big)} $
The aggregation function $ \mathrm{AGG} $ must be permutation-invariant — typical choices are sum, mean, max, or attention-weighted sum. The update function $ \mathrm{UPD} $ is usually a small MLP or a gated unit. After $ L $ layers, each node embedding incorporates information from its $ L $-hop neighborhood. For graph-level tasks, a final readout pools all node embeddings into a single vector, often via summation, mean, or a learned attention pool.
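As a concrete illustration, the sketch below implements a single message-passing layer in PyTorch, using mean aggregation and a small MLP update. The conventions are assumptions for the example, not part of the framework: node embeddings as an `(N, d)` tensor, edges as a `(2, E)` index tensor of `(source, destination)` pairs, and the particular layer sizes.

```python
import torch
import torch.nn as nn

class MessagePassingLayer(nn.Module):
    """One message-passing step: mean-AGG over neighbors, then an MLP UPD."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # UPD consumes the concatenation [h_v, m_v]
        self.update = nn.Sequential(nn.Linear(2 * in_dim, out_dim), nn.ReLU())

    def forward(self, h: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # h: (N, in_dim) node embeddings; edge_index: (2, E) rows = (src, dst)
        src, dst = edge_index
        # AGG: sum neighbor embeddings into each destination, then mean-normalize
        m = torch.zeros_like(h).index_add_(0, dst, h[src])
        deg = torch.zeros(h.size(0)).index_add_(0, dst, torch.ones(dst.numel()))
        m = m / deg.clamp(min=1).unsqueeze(-1)
        # UPD: combine the node's previous embedding with its aggregated message
        return self.update(torch.cat([h, m], dim=-1))

# Toy usage: 4 nodes with 8 features, edges 0->1, 1->2, 2->3
h = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
out = MessagePassingLayer(8, 16)(h, edge_index)  # shape (4, 16)
```

Stacking $ L $ such layers realizes the $ L $-hop receptive field described above; swapping the mean for sum, max, or attention changes only the AGG line.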
Common architectures
Graph Convolutional Networks (GCN) define a layer as a symmetric normalization of the adjacency matrix followed by a linear transform and nonlinearity:[3]
$ {\displaystyle H^{(l+1)} = \sigma\!\left(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} H^{(l)} W^{(l)}\right)} $
where $ \tilde{A} = A + I $ includes self-loops, $ \tilde{D} $ is its degree matrix, and $ \sigma $ is a nonlinearity such as ReLU. GCN is essentially a mean-aggregator with degree-based weighting, derived as a first-order approximation of spectral graph convolution.
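The propagation rule translates almost directly into code. The dense-matrix sketch below assumes a graph small enough to materialize $ A $ explicitly; production implementations use sparse operations instead.

```python
import torch

def gcn_layer(H: torch.Tensor, A: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """sigma(D~^{-1/2} A~ D~^{-1/2} H W), with a dense adjacency for clarity."""
    A_tilde = A + torch.eye(A.size(0))   # add self-loops
    d = A_tilde.sum(dim=1)               # degrees of A~
    D_inv_sqrt = torch.diag(d.pow(-0.5))
    return torch.relu(D_inv_sqrt @ A_tilde @ D_inv_sqrt @ H @ W)
```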
GraphSAGE replaces full-neighborhood aggregation with a fixed-size sample, allowing inductive learning on graphs too large to fit in memory and on previously unseen nodes.[4]
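A minimal sketch of the fixed-size sampling step, assuming the graph is stored as an adjacency list (a dict from node to neighbor list); padding by resampling with replacement is one common convention, not the only one:

```python
import random

def sample_neighbors(adj: dict, node, k: int) -> list:
    """Draw a fixed-size neighbor sample for one node, GraphSAGE-style."""
    nbrs = adj[node]
    if not nbrs:
        return []                                   # isolated node: nothing to aggregate
    if len(nbrs) >= k:
        return random.sample(nbrs, k)               # sample without replacement
    return [random.choice(nbrs) for _ in range(k)]  # pad by resampling
```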
Graph Attention Networks (GAT) assign learnable weights to each neighbor via an Attention Mechanism, so that informative neighbors contribute more than uninformative ones:[5]
$ {\displaystyle \alpha_{vu} = \mathrm{softmax}_u\!\left(\mathrm{LeakyReLU}\big(a^\top [W h_v \,\|\, W h_u]\big)\right)} $
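The sketch below computes these coefficients for a single attention head, assuming the same `(2, E)` edge-index convention as above; the per-neighborhood softmax becomes a segment softmax over each node's incoming edges.

```python
import torch
import torch.nn.functional as F

def gat_attention(h, W, a, edge_index):
    """Per-edge GAT coefficients alpha_{vu} for a single attention head."""
    src, dst = edge_index                 # edge u -> v stored as (src=u, dst=v)
    Wh = h @ W                            # (N, d') transformed embeddings
    # logits e_{vu} = LeakyReLU(a^T [W h_v || W h_u]), one scalar per edge
    e = F.leaky_relu(torch.cat([Wh[dst], Wh[src]], dim=-1) @ a, negative_slope=0.2)
    # softmax over each destination node's incoming edges (segment softmax)
    e = e - e.max()                       # global shift for numerical stability
    num = e.exp()
    denom = torch.zeros(h.size(0)).index_add_(0, dst, num)
    return num / denom[dst]               # (E,) attention weights, sum to 1 per node
```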
Graph Isomorphism Networks (GIN) use sum aggregation with an injective MLP update and provably match the discriminative power of the 1-dimensional Weisfeiler-Leman (1-WL) graph isomorphism test, the strongest expressiveness achievable by standard message-passing GNNs.[6]
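A minimal GIN layer under the same conventions, with the learnable $ \epsilon $ of the original formulation; the hidden sizes are illustrative.

```python
import torch
import torch.nn as nn

class GINLayer(nn.Module):
    """h_v <- MLP((1 + eps) * h_v + sum_{u in N(v)} h_u), with learnable eps."""

    def __init__(self, dim: int):
        super().__init__()
        self.eps = nn.Parameter(torch.zeros(1))
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, h, edge_index):
        src, dst = edge_index
        agg = torch.zeros_like(h).index_add_(0, dst, h[src])  # injective sum aggregation
        return self.mlp((1 + self.eps) * h + agg)
```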
Training and inference
GNNs are trained with standard gradient-based optimizers such as Adam using task-specific loss functions: cross-entropy for node classification, Mean Squared Error for regression, and contrastive or margin losses for link prediction. Three task regimes are common:
- Node-level: predict a label for each node (e.g. category in a citation network). A semi-supervised setting with few labeled nodes is typical.
- Edge-level: predict whether two nodes should be connected, used in recommendation and knowledge graph completion.
- Graph-level: predict a property of the entire graph, such as molecular toxicity or solubility.
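A minimal full-batch sketch of the node-level regime: cross-entropy is computed only on the labeled nodes, which is what makes the setting semi-supervised. The `model(h0, edge_index)` callable (for example, a stack of the layers above), the boolean `train_mask`, and the Adam hyperparameters are assumptions for the example.

```python
import torch
import torch.nn.functional as F

def train_node_classifier(model, h0, edge_index, labels, train_mask, epochs=200):
    """Full-batch semi-supervised training: loss on labeled nodes only."""
    opt = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    for _ in range(epochs):
        model.train()
        opt.zero_grad()
        logits = model(h0, edge_index)   # forward pass over the whole graph
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
        loss.backward()                  # gradients flow only from labeled nodes
        opt.step()
    return model
```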
Scaling to large graphs requires care. Full-batch training stores all node embeddings in memory, which is infeasible beyond a few million nodes. Common alternatives include neighbor sampling (GraphSAGE-style), subgraph sampling (Cluster-GCN, GraphSAINT), and historical embedding caches that trade staleness for memory. At inference time, GNNs are inductive when the architecture parameterizes a function of local features rather than node identities, allowing application to unseen graphs.
Expressiveness and limitations
A standard message-passing GNN is at most as powerful as the 1-Weisfeiler-Leman test for distinguishing graphs, which means there exist non-isomorphic graphs that no MPNN can tell apart. This motivates higher-order variants such as $ k $-WL GNNs, subgraph GNNs that augment node features with structural identifiers, and equivariant transformers over graphs.
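A classic pair that 1-WL cannot separate is one 6-cycle versus two disjoint triangles: every node in both graphs has degree 2, so color refinement never breaks the symmetry. The sketch below runs 1-WL on both (using Python's `hash` as a convenient color-compression step) and confirms that the final color multisets coincide, even though the graphs are not isomorphic.

```python
from collections import Counter

def wl_colors(adj: dict, rounds: int = 3) -> Counter:
    """1-WL color refinement; returns the multiset of final node colors."""
    colors = {v: 0 for v in adj}  # uniform initial coloring
    for _ in range(rounds):
        colors = {v: hash((colors[v], tuple(sorted(colors[u] for u in adj[v]))))
                  for v in adj}
    return Counter(colors.values())

# Two disjoint triangles vs one hexagon: non-isomorphic, both 2-regular
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1],
                 3: [4, 5], 4: [3, 5], 5: [3, 4]}
hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
assert wl_colors(two_triangles) == wl_colors(hexagon)  # 1-WL cannot tell them apart
```

Since a standard MPNN refines node states in lockstep with 1-WL, any permutation-invariant readout over these graphs produces identical graph embeddings.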
Two well-known pathologies affect deep GNNs:
- Oversmoothing: as depth increases, repeated aggregation drives node embeddings toward indistinguishable values, hurting accuracy. Residual connections, normalization, and PairNorm-style regularizers mitigate this but do not fully solve it. A small numerical illustration follows this list.
- Oversquashing: information from distant nodes is compressed through narrow bottlenecks, so long-range dependencies are hard to capture in graphs with small bottleneck sets. Graph rewiring and graph transformers attempt to address this by adding shortcut edges or replacing local message passing with global attention.
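Oversmoothing is easy to observe numerically. The sketch below repeatedly applies the symmetric-normalized propagation matrix from the GCN layer above to random features on a small ring graph; the spread of the node embeddings around their mean shrinks toward zero as the effective depth grows. The graph, feature width, and depths are illustrative assumptions.

```python
import torch

# Ring graph with self-loops, symmetric normalization as in the GCN layer
N = 8
A = torch.zeros(N, N)
for i in range(N):
    A[i, (i + 1) % N] = A[(i + 1) % N, i] = 1.0
A_tilde = A + torch.eye(N)
d = A_tilde.sum(dim=1)
P = torch.diag(d.pow(-0.5)) @ A_tilde @ torch.diag(d.pow(-0.5))

H = torch.randn(N, 4)
for depth in (1, 2, 4, 8, 16, 32):
    H_k = torch.linalg.matrix_power(P, depth) @ H
    spread = (H_k - H_k.mean(dim=0)).norm().item()   # how distinguishable nodes remain
    print(f"depth {depth:2d}: spread {spread:.4f}")  # shrinks toward zero
```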
GNNs also inherit standard pitfalls of deep learning: sensitivity to distribution shift, dependence on high-quality node features, and limited robustness to adversarial perturbations of the graph structure.
Comparisons with other models
Compared with Convolutional Neural Networks, GNNs generalize the convolution operation to irregular domains; a CNN can be viewed as a GNN on a regular grid graph with translation-invariant filters. Compared with Recurrent Neural Networks, GNNs operate on sets of neighbors rather than ordered sequences, but iterated message passing can be unrolled into a recurrent computation. The Transformer is closely related: self-attention is equivalent to a GNN operating on a fully connected graph, and recent graph transformer architectures combine local message passing with global attention plus structural positional encodings to capture both inductive bias and long-range dependencies.
Applications
GNNs have produced state-of-the-art results in domains where relational structure carries much of the signal. In chemistry and materials science they predict molecular properties, accelerate density functional theory calculations, and propose synthesis pathways. In drug discovery they score binding affinities and screen candidate molecules. In recommendation, they model user-item interaction graphs at production scale. In physics they have learned simulators of particle and fluid dynamics. In code analysis they reason over abstract syntax trees and data-flow graphs. The versatility of the message-passing abstraction, combined with the prevalence of relational data, has made GNNs a core tool in modern machine learning.