Path: blob/master/examples/graph/gat_node_classification.py
3507 views
"""1Title: Graph attention network (GAT) for node classification2Author: [akensert](https://github.com/akensert)3Date created: 2021/09/134Last modified: 2021/12/265Description: An implementation of a Graph Attention Network (GAT) for node classification.6Accelerator: GPU7"""89"""10## Introduction1112[Graph neural networks](https://en.wikipedia.org/wiki/Graph_neural_network)13is the preferred neural network architecture for processing data structured as14graphs (for example, social networks or molecule structures), yielding15better results than fully-connected networks or convolutional networks.1617In this tutorial, we will implement a specific graph neural network known as a18[Graph Attention Network](https://arxiv.org/abs/1710.10903) (GAT) to predict labels of19scientific papers based on what type of papers cite them (using the20[Cora](https://linqs.soe.ucsc.edu/data) dataset).2122### References2324For more information on GAT, see the original paper25[Graph Attention Networks](https://arxiv.org/abs/1710.10903) as well as26[DGL's Graph Attention Networks](https://docs.dgl.ai/en/0.4.x/tutorials/models/1_gnn/9_gat.html)27documentation.28"""2930"""31### Import packages32"""3334import tensorflow as tf35from tensorflow import keras36from tensorflow.keras import layers37import numpy as np38import pandas as pd39import os40import warnings4142warnings.filterwarnings("ignore")43pd.set_option("display.max_columns", 6)44pd.set_option("display.max_rows", 6)45np.random.seed(2)4647"""48## Obtain the dataset4950The preparation of the [Cora dataset](https://linqs.soe.ucsc.edu/data) follows that of the51[Node classification with Graph Neural Networks](https://keras.io/examples/graph/gnn_citations/)52tutorial. Refer to this tutorial for more details on the dataset and exploratory data analysis.53In brief, the Cora dataset consists of two files: `cora.cites` which contains *directed links* (citations) between54papers; and `cora.content` which contains *features* of the corresponding papers and one55of seven labels (the *subject* of the paper).56"""5758zip_file = keras.utils.get_file(59fname="cora.tgz",60origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",61extract=True,62)6364data_dir = os.path.join(os.path.dirname(zip_file), "cora")6566citations = pd.read_csv(67os.path.join(data_dir, "cora.cites"),68sep="\t",69header=None,70names=["target", "source"],71)7273papers = pd.read_csv(74os.path.join(data_dir, "cora.content"),75sep="\t",76header=None,77names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],78)7980class_values = sorted(papers["subject"].unique())81class_idx = {name: id for id, name in enumerate(class_values)}82paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}8384papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])85citations["source"] = citations["source"].apply(lambda name: paper_idx[name])86citations["target"] = citations["target"].apply(lambda name: paper_idx[name])87papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])8889print(citations)9091print(papers)9293"""94### Split the dataset95"""9697# Obtain random indices98random_indices = np.random.permutation(range(papers.shape[0]))99100# 50/50 split101train_data = papers.iloc[random_indices[: len(random_indices) // 2]]102test_data = papers.iloc[random_indices[len(random_indices) // 2 :]]103104"""105### Prepare the graph data106"""107108# Obtain paper indices which will be used to gather node states109# from the graph later on when training the model110train_indices = train_data["paper_id"].to_numpy()111test_indices = test_data["paper_id"].to_numpy()112113# Obtain ground truth labels corresponding to each paper_id114train_labels = train_data["subject"].to_numpy()115test_labels = test_data["subject"].to_numpy()116117# Define graph, namely an edge tensor and a node feature tensor118edges = tf.convert_to_tensor(citations[["target", "source"]])119node_states = tf.convert_to_tensor(papers.sort_values("paper_id").iloc[:, 1:-1])120121# Print shapes of the graph122print("Edges shape:\t\t", edges.shape)123print("Node features shape:", node_states.shape)124125"""126## Build the model127128GAT takes as input a graph (namely an edge tensor and a node feature tensor) and129outputs \[updated\] node states. The node states are, for each target node, neighborhood130aggregated information of *N*-hops (where *N* is decided by the number of layers of the131GAT). Importantly, in contrast to the132[graph convolutional network](https://arxiv.org/abs/1609.02907) (GCN)133the GAT makes use of attention mechanisms134to aggregate information from neighboring nodes (or *source nodes*). In other words, instead of simply135averaging/summing node states from source nodes (*source papers*) to the target node (*target papers*),136GAT first applies normalized attention scores to each source node state and then sums.137"""138139"""140### (Multi-head) graph attention layer141142The GAT model implements multi-head graph attention layers. The `MultiHeadGraphAttention`143layer is simply a concatenation (or averaging) of multiple graph attention layers144(`GraphAttention`), each with separate learnable weights `W`. The `GraphAttention` layer145does the following:146147Consider inputs node states `h^{l}` which are linearly transformed by `W^{l}`, resulting in `z^{l}`.148149For each target node:1501511. Computes pair-wise attention scores `a^{l}^{T}(z^{l}_{i}||z^{l}_{j})` for all `j`,152resulting in `e_{ij}` (for all `j`).153`||` denotes a concatenation, `_{i}` corresponds to the target node, and `_{j}`154corresponds to a given 1-hop neighbor/source node.1552. Normalizes `e_{ij}` via softmax, so as the sum of incoming edges' attention scores156to the target node (`sum_{k}{e_{norm}_{ik}}`) will add up to 1.1573. Applies attention scores `e_{norm}_{ij}` to `z_{j}`158and adds it to the new target node state `h^{l+1}_{i}`, for all `j`.159"""160161162class GraphAttention(layers.Layer):163def __init__(164self,165units,166kernel_initializer="glorot_uniform",167kernel_regularizer=None,168**kwargs,169):170super().__init__(**kwargs)171self.units = units172self.kernel_initializer = keras.initializers.get(kernel_initializer)173self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)174175def build(self, input_shape):176self.kernel = self.add_weight(177shape=(input_shape[0][-1], self.units),178trainable=True,179initializer=self.kernel_initializer,180regularizer=self.kernel_regularizer,181name="kernel",182)183self.kernel_attention = self.add_weight(184shape=(self.units * 2, 1),185trainable=True,186initializer=self.kernel_initializer,187regularizer=self.kernel_regularizer,188name="kernel_attention",189)190self.built = True191192def call(self, inputs):193node_states, edges = inputs194195# Linearly transform node states196node_states_transformed = tf.matmul(node_states, self.kernel)197198# (1) Compute pair-wise attention scores199node_states_expanded = tf.gather(node_states_transformed, edges)200node_states_expanded = tf.reshape(201node_states_expanded, (tf.shape(edges)[0], -1)202)203attention_scores = tf.nn.leaky_relu(204tf.matmul(node_states_expanded, self.kernel_attention)205)206attention_scores = tf.squeeze(attention_scores, -1)207208# (2) Normalize attention scores209attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))210attention_scores_sum = tf.math.unsorted_segment_sum(211data=attention_scores,212segment_ids=edges[:, 0],213num_segments=tf.reduce_max(edges[:, 0]) + 1,214)215attention_scores_sum = tf.repeat(216attention_scores_sum, tf.math.bincount(tf.cast(edges[:, 0], "int32"))217)218attention_scores_norm = attention_scores / attention_scores_sum219220# (3) Gather node states of neighbors, apply attention scores and aggregate221node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])222out = tf.math.unsorted_segment_sum(223data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],224segment_ids=edges[:, 0],225num_segments=tf.shape(node_states)[0],226)227return out228229230class MultiHeadGraphAttention(layers.Layer):231def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):232super().__init__(**kwargs)233self.num_heads = num_heads234self.merge_type = merge_type235self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]236237def call(self, inputs):238atom_features, pair_indices = inputs239240# Obtain outputs from each attention head241outputs = [242attention_layer([atom_features, pair_indices])243for attention_layer in self.attention_layers244]245# Concatenate or average the node states from each head246if self.merge_type == "concat":247outputs = tf.concat(outputs, axis=-1)248else:249outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)250# Activate and return node states251return tf.nn.relu(outputs)252253254"""255### Implement training logic with custom `train_step`, `test_step`, and `predict_step` methods256257Notice, the GAT model operates on the entire graph (namely, `node_states` and258`edges`) in all phases (training, validation and testing). Hence, `node_states` and259`edges` are passed to the constructor of the `keras.Model` and used as attributes.260The difference between the phases are the indices (and labels), which gathers261certain outputs (`tf.gather(outputs, indices)`).262263"""264265266class GraphAttentionNetwork(keras.Model):267def __init__(268self,269node_states,270edges,271hidden_units,272num_heads,273num_layers,274output_dim,275**kwargs,276):277super().__init__(**kwargs)278self.node_states = node_states279self.edges = edges280self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")281self.attention_layers = [282MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)283]284self.output_layer = layers.Dense(output_dim)285286def call(self, inputs):287node_states, edges = inputs288x = self.preprocess(node_states)289for attention_layer in self.attention_layers:290x = attention_layer([x, edges]) + x291outputs = self.output_layer(x)292return outputs293294def train_step(self, data):295indices, labels = data296297with tf.GradientTape() as tape:298# Forward pass299outputs = self([self.node_states, self.edges])300# Compute loss301loss = self.compiled_loss(labels, tf.gather(outputs, indices))302# Compute gradients303grads = tape.gradient(loss, self.trainable_weights)304# Apply gradients (update weights)305optimizer.apply_gradients(zip(grads, self.trainable_weights))306# Update metric(s)307self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))308309return {m.name: m.result() for m in self.metrics}310311def predict_step(self, data):312indices = data313# Forward pass314outputs = self([self.node_states, self.edges])315# Compute probabilities316return tf.nn.softmax(tf.gather(outputs, indices))317318def test_step(self, data):319indices, labels = data320# Forward pass321outputs = self([self.node_states, self.edges])322# Compute loss323loss = self.compiled_loss(labels, tf.gather(outputs, indices))324# Update metric(s)325self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))326327return {m.name: m.result() for m in self.metrics}328329330"""331### Train and evaluate332"""333334# Define hyper-parameters335HIDDEN_UNITS = 100336NUM_HEADS = 8337NUM_LAYERS = 3338OUTPUT_DIM = len(class_values)339340NUM_EPOCHS = 100341BATCH_SIZE = 256342VALIDATION_SPLIT = 0.1343LEARNING_RATE = 3e-1344MOMENTUM = 0.9345346loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)347optimizer = keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM)348accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="acc")349early_stopping = keras.callbacks.EarlyStopping(350monitor="val_acc", min_delta=1e-5, patience=5, restore_best_weights=True351)352353# Build model354gat_model = GraphAttentionNetwork(355node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM356)357358# Compile model359gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])360361gat_model.fit(362x=train_indices,363y=train_labels,364validation_split=VALIDATION_SPLIT,365batch_size=BATCH_SIZE,366epochs=NUM_EPOCHS,367callbacks=[early_stopping],368verbose=2,369)370371_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=0)372373print("--" * 38 + f"\nTest Accuracy {test_accuracy*100:.1f}%")374375"""376### Predict (probabilities)377"""378test_probs = gat_model.predict(x=test_indices)379380mapping = {v: k for (k, v) in class_idx.items()}381382for i, (probs, label) in enumerate(zip(test_probs[:10], test_labels[:10])):383print(f"Example {i+1}: {mapping[label]}")384for j, c in zip(probs, class_idx.keys()):385print(f"\tProbability of {c: <24} = {j*100:7.3f}%")386print("---" * 20)387388"""389## Conclusions390391The results look OK! The GAT model seems to correctly predict the subjects of the papers,392based on what they cite, about 80% of the time. Further improvements could be393made by fine-tuning the hyper-parameters of the GAT. For instance, try changing the number of layers,394the number of hidden units, or the optimizer/learning rate; add regularization (e.g., dropout);395or modify the preprocessing step. We could also try to implement *self-loops*396(i.e., paper X cites paper X) and/or make the graph *undirected*.397"""398399400