Node Classification with Graph Neural Networks
Author: Khalid Salama
Date created: 2021/05/30
Last modified: 2021/05/30
Description: Implementing a graph neural network model for predicting the topic of a paper given its citations.
Introduction
Many datasets in various machine learning (ML) applications have structural relationships between their entities, which can be represented as graphs. Such applications include social and communication network analysis, traffic prediction, and fraud detection. Graph representation learning aims to build and train models for graph datasets to be used for a variety of ML tasks.
This example demonstrates a simple implementation of a Graph Neural Network (GNN) model. The model is used for a node prediction task on the Cora dataset to predict the subject of a paper given its words and citation network.
Note that we implement a Graph Convolution layer from scratch to provide a better understanding of how it works. However, there are a number of specialized TensorFlow-based libraries that provide rich GNN APIs, such as Spektral, StellarGraph, and GraphNets.
Setup
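A minimal set of imports sufficient for the sketches below (assuming TensorFlow 2.x with its bundled Keras, plus pandas, NumPy, NetworkX, and Matplotlib):

```python
import os

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
```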
Prepare the Dataset
The Cora dataset consists of 2,708 scientific papers classified into one of seven classes. The citation network consists of 5,429 links. Each paper has a binary word vector of size 1,433, indicating the presence of a corresponding word.
Download the dataset
The dataset has two tab-separated files: cora.cites and cora.content.
- The cora.cites file includes the citation records, with two columns: cited_paper_id (target) and citing_paper_id (source).
- The cora.content file includes the paper content records, with 1,435 columns: paper_id, subject, and 1,433 binary features.
Let's download the dataset.
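A sketch of the download step, assuming the LINQS mirror of the Cora archive used by the published example is still available:

```python
# Download and extract the Cora archive; get_file caches it under ~/.keras.
zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)
data_dir = os.path.join(os.path.dirname(zip_file), "cora")
```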
Process and visualize the dataset
Then we load the citations data into a Pandas DataFrame.
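For example (assuming the data_dir path from the download sketch above):

```python
# Each row of cora.cites is a (cited paper, citing paper) pair.
citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)
print("Citations shape:", citations.shape)
```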
Now we display a sample of the citations DataFrame. The target column includes the paper ids cited by the paper ids in the source column.
Now let's load the papers data into a Pandas DataFrame.
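A possible loading step; the term_0 to term_1432 column names are an illustrative choice, not something mandated by the dataset:

```python
column_names = ["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"]
papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"),
    sep="\t",
    header=None,
    names=column_names,
)
print("Papers shape:", papers.shape)
```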
Now we display a sample of the papers DataFrame. The DataFrame includes the paper_id and the subject columns, as well as 1,433 binary columns representing whether a term exists in the paper or not.
Let's display the count of the papers in each subject.
We convert the paper ids and the subjects into zero-based indices.
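One way to do the conversion, assuming the citations and papers DataFrames from the sketches above:

```python
# Map raw identifiers to contiguous zero-based indices.
class_values = sorted(papers["subject"].unique())
class_idx = {name: idx for idx, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])
```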
Now let's visualize the citation graph. Each node in the graph represents a paper, and the color of the node corresponds to its subject. Note that we only show a sample of the papers in the dataset.
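A plotting sketch with NetworkX; the sample size of 1,500 edges is an arbitrary choice for readability:

```python
plt.figure(figsize=(10, 10))
# Build a graph from a random sample of the citation edges.
cora_graph = nx.from_pandas_edgelist(citations.sample(n=1500))
# Colour each sampled node by the subject of the corresponding paper.
subjects = list(papers[papers["paper_id"].isin(list(cora_graph.nodes))]["subject"])
nx.draw_spring(cora_graph, node_size=15, node_color=subjects)
plt.show()
```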
Split the dataset into stratified train and test sets
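A sketch of a roughly 50/50 per-subject split (the exact split ratio is an assumption):

```python
train_data, test_data = [], []
for _, group_data in papers.groupby("subject"):
    # Select about half of each subject group for training.
    random_selection = np.random.rand(len(group_data.index)) <= 0.5
    train_data.append(group_data[random_selection])
    test_data.append(group_data[~random_selection])

train_data = pd.concat(train_data).sample(frac=1)
test_data = pd.concat(test_data).sample(frac=1)
print("Train data shape:", train_data.shape)
print("Test data shape:", test_data.shape)
```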
Implement Train and Evaluate Experiment
This function compiles and trains an input model using the given training data.
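A sketch of such a function; the hyperparameter values below are illustrative defaults rather than the exact settings used to produce any reported results:

```python
hidden_units = [32, 32]
learning_rate = 0.01
dropout_rate = 0.5
num_epochs = 300
batch_size = 256


def run_experiment(model, x_train, y_train):
    # Compile with Adam and sparse categorical cross-entropy on logits.
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
    )
    # Stop early once validation accuracy stops improving.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_acc", patience=50, restore_best_weights=True
    )
    return model.fit(
        x=x_train,
        y=y_train,
        epochs=num_epochs,
        batch_size=batch_size,
        validation_split=0.15,
        callbacks=[early_stopping],
    )
```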
This function displays the loss and accuracy curves of the model during training.
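A simple Matplotlib sketch of that helper (assuming the acc/val_acc metric names used in run_experiment above):

```python
def display_learning_curves(history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    ax1.plot(history.history["loss"], label="train")
    ax1.plot(history.history["val_loss"], label="validation")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend()
    ax2.plot(history.history["acc"], label="train")
    ax2.plot(history.history["val_acc"], label="validation")
    ax2.set_xlabel("Epochs")
    ax2.set_ylabel("Accuracy")
    ax2.legend()
    plt.show()
```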
Implement Feedforward Network (FFN) Module
We will use this module in the baseline and the GNN models.
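A possible implementation as a small Sequential stack; the GELU activation and the BatchNorm/Dropout ordering are design choices, not requirements:

```python
def create_ffn(hidden_units, dropout_rate, name=None):
    # A stack of [BatchNorm -> Dropout -> Dense(GELU)] blocks.
    fnn_layers = []
    for units in hidden_units:
        fnn_layers.append(layers.BatchNormalization())
        fnn_layers.append(layers.Dropout(dropout_rate))
        fnn_layers.append(layers.Dense(units, activation=tf.nn.gelu))
    return keras.Sequential(fnn_layers, name=name)
```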
Build a Baseline Neural Network Model
Prepare the data for the baseline model
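A sketch of the data preparation, assuming the term_* column names and the train_data/test_data split from the earlier sketches:

```python
feature_names = [f"term_{idx}" for idx in range(1433)]
num_features = len(feature_names)
num_classes = len(class_idx)

# Feature matrices and label vectors for the baseline model.
x_train = train_data[feature_names].to_numpy()
x_test = test_data[feature_names].to_numpy()
y_train = train_data["subject"]
y_test = test_data["subject"]
```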
Implement a baseline classifier
We add five FFN blocks with skip connections, so that we generate a baseline model with roughly the same number of parameters as the GNN models to be built later.
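One way to build such a model with the Keras functional API; the block count and layer names are illustrative:

```python
def create_baseline_model(hidden_units, num_classes, dropout_rate=0.2):
    inputs = layers.Input(shape=(num_features,), name="input_features")
    x = create_ffn(hidden_units, dropout_rate, name="ffn_block1")(inputs)
    for block_idx in range(4):
        # FFN block followed by a skip connection.
        x1 = create_ffn(hidden_units, dropout_rate, name=f"ffn_block{block_idx + 2}")(x)
        x = layers.Add(name=f"skip_connection{block_idx + 2}")([x, x1])
    # Unnormalized class scores (logits).
    logits = layers.Dense(num_classes, name="logits")(x)
    return keras.Model(inputs=inputs, outputs=logits, name="baseline")


baseline_model = create_baseline_model(hidden_units, num_classes, dropout_rate)
baseline_model.summary()
```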
Train the baseline classifier
Let's plot the learning curves.
Now we evaluate the baseline model on the test data split.
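For example, after training the baseline with run_experiment(baseline_model, x_train, y_train):

```python
_, test_accuracy = baseline_model.evaluate(x=x_test, y=y_test, verbose=0)
print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")
```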
Examine the baseline model predictions
Let's create new data instances by randomly generating binary word vectors with respect to the word presence probabilities.
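A sketch of the generator; it samples each word independently with the empirical presence probability estimated from the training features:

```python
def generate_random_instances(num_instances):
    token_probability = x_train.mean(axis=0)
    instances = []
    for _ in range(num_instances):
        probabilities = np.random.uniform(size=len(token_probability))
        # A word is present when its random draw falls below its presence probability.
        instances.append((probabilities <= token_probability).astype(int))
    return np.array(instances)


new_instances = generate_random_instances(num_classes)
```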
Now we show the baseline model predictions given these randomly generated instances.
Build a Graph Neural Network Model
Prepare the data for the graph model
Preparing and loading the graphs data into the model for training is the most challenging part in GNN models, which is addressed in different ways by the specialised libraries. In this example, we show a simple approach for preparing and using graph data that is suitable if your dataset consists of a single graph that fits entirely in memory.
The graph data is represented by the graph_info tuple, which consists of the following three elements (a construction sketch follows the list):
- node_features: This is a [num_nodes, num_features] NumPy array that includes the node features. In this dataset, the nodes are the papers, and the node_features are the word-presence binary vectors of each paper.
- edges: This is a [2, num_edges] NumPy array representing a sparse adjacency matrix (an edge list of source and target node indices) of the links between the nodes. In this example, the links are the citations between the papers.
- edge_weights (optional): This is a [num_edges] NumPy array that includes the edge weights, which quantify the relationships between nodes in the graph. In this example, there are no weights for the paper citations.
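A sketch of building graph_info from the DataFrames prepared earlier:

```python
# Edge list of shape [2, num_edges]: row 0 holds the citing (source) node indices,
# row 1 holds the cited (target) node indices.
edges = citations[["source", "target"]].to_numpy().T
# Uniform edge weights, since the citations carry no weights.
edge_weights = tf.ones(shape=edges.shape[1])
# Node features of shape [num_nodes, num_features], ordered by node index.
node_features = tf.cast(
    papers.sort_values("paper_id")[feature_names].to_numpy(), dtype=tf.dtypes.float32
)
graph_info = (node_features, edges, edge_weights)
```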
Implement a graph convolution layer
We implement a graph convolution module as a Keras Layer. Our GraphConvLayer performs the following steps (a condensed sketch follows the list):
- Prepare: The input node representations are processed using a FFN to produce a message. You can simplify the processing by only applying a linear transformation to the representations.
- Aggregate: The messages of the neighbours of each node are aggregated with respect to the edge_weights using a permutation-invariant pooling operation, such as sum, mean, or max, to prepare a single aggregated message for each node. See, for example, the tf.math.unsorted_segment_sum API, which is used to aggregate neighbour messages.
- Update: The node_repesentations and aggregated_messages, both of shape [num_nodes, representation_dim], are combined and processed to produce the new state of the node representations (node embeddings). If combination_type is gru, the node_repesentations and aggregated_messages are stacked to create a sequence, then processed by a GRU layer. Otherwise, the node_repesentations and aggregated_messages are added or concatenated, then processed using a FFN.
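Below is a condensed sketch of such a layer, restricted to sum aggregation and concatenation for the update step; the mean/max pooling and GRU combination variants described above are omitted for brevity. It reuses the create_ffn helper sketched earlier.

```python
class GraphConvLayer(layers.Layer):
    """Condensed graph convolution: sum aggregation + concat update."""

    def __init__(self, hidden_units, dropout_rate=0.2, **kwargs):
        super().__init__(**kwargs)
        # FFN that turns neighbour representations into messages (the "prepare" step).
        self.ffn_prepare = create_ffn(hidden_units, dropout_rate)
        # FFN that produces the updated node embeddings (the "update" step).
        self.update_fn = create_ffn(hidden_units, dropout_rate)

    def prepare(self, node_repesentations, weights=None):
        # Produce a message per edge from the neighbour's representation.
        messages = self.ffn_prepare(node_repesentations)
        if weights is not None:
            messages = messages * tf.expand_dims(weights, -1)
        return messages

    def aggregate(self, node_indices, neighbour_messages, num_nodes):
        # Sum the incoming messages per target node (a permutation-invariant pooling).
        return tf.math.unsorted_segment_sum(
            neighbour_messages, node_indices, num_segments=num_nodes
        )

    def update(self, node_repesentations, aggregated_messages):
        # Concatenate old representations with aggregated messages, then apply an FFN.
        h = tf.concat([node_repesentations, aggregated_messages], axis=1)
        return self.update_fn(h)

    def call(self, inputs):
        node_repesentations, edges, edge_weights = inputs
        node_indices, neighbour_indices = edges[0], edges[1]
        # Representations of the neighbours that send messages along each edge.
        neighbour_repesentations = tf.gather(node_repesentations, neighbour_indices)
        neighbour_messages = self.prepare(neighbour_repesentations, edge_weights)
        aggregated_messages = self.aggregate(
            node_indices, neighbour_messages, tf.shape(node_repesentations)[0]
        )
        return self.update(node_repesentations, aggregated_messages)
```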
The technique implemented here uses ideas from Graph Convolutional Networks, GraphSage, Graph Isomorphism Network, Simple Graph Networks, and Gated Graph Sequence Neural Networks. Two other key techniques that are not covered are Graph Attention Networks and Message Passing Neural Networks.
Implement a graph neural network node classifier
The GNN classification model follows the Design Space for Graph Neural Networks approach, as follows:
Apply preprocessing using FFN to the node features to generate initial node representations.
Apply one or more graph convolutional layers, with skip connections, to the node representations to produce node embeddings.
Apply post-processing using FFN to the node embeddings to generate the final node embeddings.
Feed the node embeddings into a Softmax layer to predict the node class.
Each added graph convolutional layer captures information from a further level of neighbours. However, adding many graph convolutional layers can cause oversmoothing, where the model produces similar embeddings for all the nodes.
Note that the graph_info is passed to the constructor of the Keras model and used as a property of the Keras model object, rather than as input data for training or prediction. The model will accept a batch of node_indices, which are used to look up the node features and neighbours from the graph_info.
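A condensed sketch of the classifier with two graph convolution layers; the full model in the notebook also exposes options such as the aggregation and combination types, which are omitted here. It builds on the GraphConvLayer and create_ffn sketches above.

```python
class GNNNodeClassifier(keras.Model):
    def __init__(self, graph_info, num_classes, hidden_units, dropout_rate=0.2, **kwargs):
        super().__init__(**kwargs)
        node_features, edges, edge_weights = graph_info
        # The graph is stored as model properties, not passed as training inputs.
        self.node_features = node_features
        self.edges = edges
        # Scale the edge weights so they sum to one.
        self.edge_weights = edge_weights / tf.math.reduce_sum(edge_weights)
        # Pre/post-processing FFNs plus two graph convolution layers.
        self.preprocess = create_ffn(hidden_units, dropout_rate, name="preprocess")
        self.conv1 = GraphConvLayer(hidden_units, dropout_rate, name="graph_conv1")
        self.conv2 = GraphConvLayer(hidden_units, dropout_rate, name="graph_conv2")
        self.postprocess = create_ffn(hidden_units, dropout_rate, name="postprocess")
        self.compute_logits = layers.Dense(units=num_classes, name="logits")

    def call(self, input_node_indices):
        # 1. Preprocess the node features to produce initial node representations.
        x = self.preprocess(self.node_features)
        # 2. Graph convolutions with skip connections.
        x = x + self.conv1((x, self.edges, self.edge_weights))
        x = x + self.conv2((x, self.edges, self.edge_weights))
        # 3. Postprocess to produce the final node embeddings.
        x = self.postprocess(x)
        # 4. Gather the embeddings of the requested nodes and compute class logits.
        node_embeddings = tf.gather(x, input_node_indices)
        return self.compute_logits(node_embeddings)
```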
Let's test instantiating and calling the GNN model. Notice that if you provide N node indices, the output will be a tensor of shape [N, num_classes], regardless of the size of the graph.
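For example, with the sketches above:

```python
gnn_model = GNNNodeClassifier(
    graph_info=graph_info,
    num_classes=num_classes,
    hidden_units=hidden_units,
    dropout_rate=dropout_rate,
    name="gnn_model",
)
# Three node indices in, a [3, num_classes] tensor of logits out.
print(gnn_model(tf.constant([1, 10, 100])))
gnn_model.summary()
```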
Train the GNN model
Note that we use the standard supervised cross-entropy loss to train the model. However, we could add another self-supervised loss term for the generated node embeddings that makes sure that neighbouring nodes in the graph have similar representations, while faraway nodes have dissimilar representations.
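The training call mirrors the baseline, except that the inputs are now node indices rather than feature vectors (assuming run_experiment and the train/test split from the earlier sketches):

```python
# The GNN looks up features and neighbours internally, so we only feed node indices.
x_train = train_data.paper_id.to_numpy()
history = run_experiment(gnn_model, x_train, y_train)
```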
Let's plot the learning curves
Now we evaluate the GNN model on the test data split. The results may vary depending on the training sample; however, the GNN model consistently outperforms the baseline model in terms of test accuracy.
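Note that the test inputs for the GNN are also node indices, not feature vectors:

```python
x_test = test_data.paper_id.to_numpy()
_, test_accuracy = gnn_model.evaluate(x=x_test, y=y_test, verbose=0)
print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")
```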
Examine the GNN model predictions
Let's add the new instances as nodes to the node_features, and generate links (citations) to existing nodes.
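A sketch of this step, assuming new_instances from the baseline section and node_features/edges from the graph_info sketch; here each new node cites five randomly chosen papers from one subject (the number of citations is an illustrative choice):

```python
# The new nodes take the next available indices after the existing nodes.
num_nodes = node_features.shape[0]
new_node_features = np.concatenate([node_features, new_instances]).astype("float32")
new_node_indices = [i + num_nodes for i in range(num_classes)]

# For every subject, link the corresponding new node to a few papers of that subject.
new_citations = []
for subject_idx, group in papers.groupby("subject"):
    cited_paper_indices = np.random.choice(list(group.paper_id), 5)
    for cited_paper_idx in cited_paper_indices:
        new_citations.append([new_node_indices[subject_idx], cited_paper_idx])

new_citations = np.array(new_citations).T
new_edges = np.concatenate([edges, new_citations], axis=1)
```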
Now let's update the node_features and the edges in the GNN model.
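Since the graph is held as model properties, the update amounts to reassigning them before predicting (a sketch, based on the GNNNodeClassifier sketch above):

```python
gnn_model.node_features = new_node_features
gnn_model.edges = new_edges
# New uniform edge weights (left unnormalized here for simplicity).
gnn_model.edge_weights = tf.ones(shape=new_edges.shape[1])

logits = gnn_model.predict(tf.convert_to_tensor(new_node_indices))
probabilities = keras.activations.softmax(tf.convert_to_tensor(logits)).numpy()
for i, instance_probabilities in enumerate(probabilities):
    print(f"Instance {i + 1}: {class_values[np.argmax(instance_probabilities)]}")
```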
Notice that the probabilities of the expected subjects (to which several citations are added) are higher compared to the baseline model.