Graph attention network (GAT) for node classification
Author: akensert
Date created: 2021/09/13
Last modified: 2021/12/26
Description: An implementation of a Graph Attention Network (GAT) for node classification.
Introduction
Graph neural networks are the preferred neural network architecture for processing data structured as graphs (for example, social networks or molecule structures), and typically yield better results than fully-connected or convolutional networks.
In this tutorial, we will implement a specific graph neural network known as a Graph Attention Network (GAT) to predict labels of scientific papers based on what type of papers cite them (using the Cora dataset).
References
For more information on GAT, see the original paper Graph Attention Networks as well as DGL's Graph Attention Networks documentation.
Import packages
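A minimal sketch of the imports used throughout (assuming TensorFlow 2.x with its bundled Keras; the exact set in the original notebook may differ):

```python
import os

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Fix the random seeds for reproducibility (illustrative choice).
np.random.seed(2)
tf.random.set_seed(2)
```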
Obtain the dataset
The preparation of the Cora dataset follows that of the Node classification with Graph Neural Networks tutorial. Refer to that tutorial for more details on the dataset and exploratory data analysis. In brief, the Cora dataset consists of two files: cora.cites, which contains directed links (citations) between papers, and cora.content, which contains the features of the corresponding papers and one of seven labels (the subject of the paper).
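A sketch of downloading and loading the two files with pandas, mirroring the Node classification tutorial (the download URL, column names, and re-indexing scheme are assumptions of this sketch):

```python
# Download and extract the Cora dataset (URL assumed; adjust if it has moved).
zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)
data_dir = os.path.join(os.path.dirname(zip_file), "cora")

# cora.cites: one directed citation (target, source) per row.
citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)

# cora.content: paper id, 1433 binary term features, and the subject label.
column_names = ["paper_id"] + [f"term_{i}" for i in range(1433)] + ["subject"]
papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"),
    sep="\t",
    header=None,
    names=column_names,
)

# Map paper ids and subjects to consecutive integer ids.
class_values = sorted(papers["subject"].unique())
class_idx = {name: idx for idx, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].map(paper_idx)
citations["source"] = citations["source"].map(paper_idx)
citations["target"] = citations["target"].map(paper_idx)
papers["subject"] = papers["subject"].map(class_idx)
```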
Split the dataset
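One way to split the papers into training and test sets (a random 50/50 split; the ratio is an assumption of this sketch):

```python
# Shuffle the paper rows and split them roughly in half.
random_indices = np.random.permutation(range(papers.shape[0]))
train_data = papers.iloc[random_indices[: len(random_indices) // 2]]
test_data = papers.iloc[random_indices[len(random_indices) // 2 :]]
```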
Prepare the graph data
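A sketch of how the graph tensors and label arrays could be assembled from the dataframes above; the index arrays are used later to gather node-level outputs for the corresponding papers:

```python
# Paper indices used to gather node states from the graph during training/testing.
train_indices = train_data["paper_id"].to_numpy()
test_indices = test_data["paper_id"].to_numpy()

# Ground-truth subject labels for each paper.
train_labels = train_data["subject"].to_numpy()
test_labels = test_data["subject"].to_numpy()

# The graph itself: an edge tensor (target, source) and a node feature tensor.
edges = tf.convert_to_tensor(citations[["target", "source"]].to_numpy())
node_states = tf.convert_to_tensor(
    papers.sort_values("paper_id").iloc[:, 1:-1].to_numpy().astype("float32")
)

print("Edges shape:        ", edges.shape)
print("Node features shape:", node_states.shape)
```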
Build the model
GAT takes as input a graph (namely an edge tensor and a node feature tensor) and outputs [updated] node states. The node states are, for each target node, neighborhood-aggregated information over N hops (where N is determined by the number of layers of the GAT). Importantly, in contrast to the graph convolutional network (GCN), the GAT makes use of attention mechanisms to aggregate information from neighboring nodes (or source nodes). In other words, instead of simply averaging/summing the node states of the source nodes (source papers) into the target node (target paper), GAT first applies normalized attention scores to each source node state and then sums.
(Multi-head) graph attention layer
The GAT model implements multi-head graph attention layers. The MultiHeadGraphAttention layer is simply a concatenation (or averaging) of multiple graph attention layers (GraphAttention), each with separate learnable weights W. The GraphAttention layer does the following:

Consider input node states h^{l}, which are linearly transformed by W^{l}, resulting in z^{l}.

For each target node:

1. Computes pair-wise attention scores a^{l T}(z^{l}_{i} || z^{l}_{j}) for all j, resulting in e_{ij} (for all j). Here || denotes concatenation, the subscript i corresponds to the target node, and the subscript j corresponds to a given 1-hop neighbor/source node.
2. Normalizes e_{ij} via softmax, so that the sum of the incoming edges' attention scores at the target node (sum_{k} e^{norm}_{ik}) adds up to 1.
3. Applies the normalized attention scores e^{norm}_{ij} to z_{j} and sums, for all j, to obtain the new target node state h^{l+1}_{i}.
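A sketch of how these two layers could be implemented with TensorFlow segment operations. The class names follow the description above; the leaky ReLU on the raw scores, the clipping range before the softmax, and the segment-sum bookkeeping are implementation choices of this sketch:

```python
class GraphAttention(layers.Layer):
    """A single graph attention head."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # W^{l}: linear transformation of the node states.
        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            initializer="glorot_uniform",
            name="kernel",
        )
        # a^{l}: attention kernel applied to the concatenated (target || source) states.
        self.kernel_attention = self.add_weight(
            shape=(self.units * 2, 1),
            initializer="glorot_uniform",
            name="kernel_attention",
        )

    def call(self, inputs):
        node_states, edges = inputs

        # z^{l} = h^{l} W^{l}
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) Pair-wise attention scores e_{ij} for every edge (target i, source j).
        node_states_expanded = tf.gather(node_states_transformed, edges)
        node_states_expanded = tf.reshape(
            node_states_expanded, (tf.shape(edges)[0], -1)
        )
        attention_scores = tf.nn.leaky_relu(
            tf.matmul(node_states_expanded, self.kernel_attention)
        )
        attention_scores = tf.squeeze(attention_scores, -1)

        # (2) Softmax normalization over the incoming edges of each target node.
        attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
        scores_sum = tf.math.unsorted_segment_sum(
            attention_scores, edges[:, 0], num_segments=tf.shape(node_states)[0]
        )
        attention_scores_norm = attention_scores / tf.gather(scores_sum, edges[:, 0])

        # (3) Weighted sum of the source node states into each target node.
        node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
        return tf.math.unsorted_segment_sum(
            node_states_neighbors * attention_scores_norm[:, tf.newaxis],
            edges[:, 0],
            num_segments=tf.shape(node_states)[0],
        )


class MultiHeadGraphAttention(layers.Layer):
    """Runs several GraphAttention heads and concatenates (or averages) their outputs."""

    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.merge_type = merge_type
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

    def call(self, inputs):
        node_states, edges = inputs
        outputs = [layer([node_states, edges]) for layer in self.attention_layers]
        if self.merge_type == "concat":
            outputs = tf.concat(outputs, axis=-1)
        else:
            outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)
        return tf.nn.relu(outputs)
```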
Implement training logic with custom train_step, test_step, and predict_step methods
Notice that the GAT model operates on the entire graph (namely, node_states and edges) in all phases (training, validation, and testing). Hence, node_states and edges are passed to the constructor of the keras.Model and used as attributes. The difference between the phases is the indices (and labels), which gather certain outputs (tf.gather(outputs, indices)).
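A sketch of a keras.Model subclass wiring this together. It assumes the TF 2.x-style self.compiled_loss / self.compiled_metrics API; the Dense preprocessing layer and the residual connections are choices of this sketch:

```python
class GraphAttentionNetwork(keras.Model):
    def __init__(
        self, node_states, edges, hidden_units, num_heads, num_layers, output_dim, **kwargs
    ):
        super().__init__(**kwargs)
        # The whole graph is stored on the model and reused in every phase.
        self.node_states = node_states
        self.edges = edges
        self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")
        self.attention_layers = [
            MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)
        ]
        self.output_layer = layers.Dense(output_dim)

    def call(self, inputs):
        node_states, edges = inputs
        x = self.preprocess(node_states)
        for attention_layer in self.attention_layers:
            x = attention_layer([x, edges]) + x  # residual connection
        return self.output_layer(x)

    def train_step(self, data):
        indices, labels = data
        with tf.GradientTape() as tape:
            # Forward pass on the full graph, then gather the training nodes.
            outputs = self([self.node_states, self.edges])
            loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        indices, labels = data
        outputs = self([self.node_states, self.edges])
        self.compiled_loss(labels, tf.gather(outputs, indices))
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))
        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        indices = data
        outputs = self([self.node_states, self.edges])
        return tf.nn.softmax(tf.gather(outputs, indices))
```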
Train and evaluate
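A sketch of compiling and fitting the model on the training indices. The hyper-parameter values below are illustrative choices, not necessarily those behind the accuracy reported in the conclusions:

```python
HIDDEN_UNITS = 100
NUM_HEADS = 8
NUM_LAYERS = 3
OUTPUT_DIM = len(class_values)

NUM_EPOCHS = 100
BATCH_SIZE = 256
LEARNING_RATE = 3e-1
MOMENTUM = 0.9

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM)
accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="acc")
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_acc", min_delta=1e-5, patience=5, restore_best_weights=True
)

gat_model = GraphAttentionNetwork(
    node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM
)
gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])

gat_model.fit(
    x=train_indices,
    y=train_labels,
    validation_split=0.1,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    callbacks=[early_stopping],
    verbose=2,
)

_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=0)
print(f"Test accuracy: {test_accuracy * 100:.1f}%")
```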
Predict (probabilities)
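A sketch of obtaining per-class probabilities for a few test papers (the predict_step above already applies a softmax to the gathered logits):

```python
test_probs = gat_model.predict(x=test_indices)

# Map integer class ids back to subject names.
mapping = {idx: name for name, idx in class_idx.items()}

for i, (probs, label) in enumerate(zip(test_probs[:5], test_labels[:5])):
    print(f"Example {i + 1}: true subject = {mapping[label]}")
    for prob, subject in zip(probs, class_values):
        print(f"\tP({subject}) = {prob * 100:6.2f}%")
```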
Conclusions
The results look OK! The GAT model seems to correctly predict the subjects of the papers, based on what they cite, about 80% of the time. Further improvements could be made by fine-tuning the hyper-parameters of the GAT. For instance, try changing the number of layers, the number of hidden units, or the optimizer/learning rate; add regularization (e.g., dropout); or modify the preprocessing step. We could also try to implement self-loops (i.e., paper X cites paper X) and/or make the graph undirected.