"""
2
Title: Graph attention network (GAT) for node classification
3
Author: [akensert](https://github.com/akensert)
4
Date created: 2021/09/13
5
Last modified: 2021/12/26
6
Description: An implementation of a Graph Attention Network (GAT) for node classification.
7
Accelerator: GPU
8
"""

"""
## Introduction

[Graph neural networks](https://en.wikipedia.org/wiki/Graph_neural_network)
are the preferred neural network architecture for processing data structured as
graphs (for example, social networks or molecule structures), often yielding
better results than fully-connected networks or convolutional networks.

In this tutorial, we will implement a specific graph neural network known as a
[Graph Attention Network](https://arxiv.org/abs/1710.10903) (GAT) to predict labels of
scientific papers based on the types of papers that cite them (using the
[Cora](https://linqs.soe.ucsc.edu/data) dataset).

### References

For more information on GAT, see the original paper
[Graph Attention Networks](https://arxiv.org/abs/1710.10903) as well as
[DGL's Graph Attention Networks](https://docs.dgl.ai/en/0.4.x/tutorials/models/1_gnn/9_gat.html)
documentation.
"""

"""
### Import packages
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import os
import warnings

warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", 6)
pd.set_option("display.max_rows", 6)
np.random.seed(2)

"""
49
## Obtain the dataset
50
51
The preparation of the [Cora dataset](https://linqs.soe.ucsc.edu/data) follows that of the
52
[Node classification with Graph Neural Networks](https://keras.io/examples/graph/gnn_citations/)
53
tutorial. Refer to this tutorial for more details on the dataset and exploratory data analysis.
54
In brief, the Cora dataset consists of two files: `cora.cites` which contains *directed links* (citations) between
55
papers; and `cora.content` which contains *features* of the corresponding papers and one
56
of seven labels (the *subject* of the paper).
57
"""

zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)

data_dir = os.path.join(os.path.dirname(zip_file), "cora")

citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)

papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"),
    sep="\t",
    header=None,
    names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],
)

class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])

print(citations)

print(papers)

"""
### Split the dataset
"""

# Obtain random indices
random_indices = np.random.permutation(range(papers.shape[0]))

# 50/50 split
train_data = papers.iloc[random_indices[: len(random_indices) // 2]]
test_data = papers.iloc[random_indices[len(random_indices) // 2 :]]

"""
### Prepare the graph data
"""

# Obtain paper indices which will be used to gather node states
# from the graph later on when training the model
train_indices = train_data["paper_id"].to_numpy()
test_indices = test_data["paper_id"].to_numpy()

# Obtain ground truth labels corresponding to each paper_id
train_labels = train_data["subject"].to_numpy()
test_labels = test_data["subject"].to_numpy()

# Define graph, namely an edge tensor and a node feature tensor
edges = tf.convert_to_tensor(citations[["target", "source"]])
node_states = tf.convert_to_tensor(papers.sort_values("paper_id").iloc[:, 1:-1])

# Print shapes of the graph
print("Edges shape:\t\t", edges.shape)
print("Node features shape:", node_states.shape)

"""
## Build the model

GAT takes as input a graph (namely an edge tensor and a node feature tensor) and
outputs \[updated\] node states. The node states are, for each target node, neighborhood
aggregated information of *N*-hops (where *N* is decided by the number of layers of the
GAT). Importantly, in contrast to the
[graph convolutional network](https://arxiv.org/abs/1609.02907) (GCN),
the GAT makes use of attention mechanisms
to aggregate information from neighboring nodes (or *source nodes*). In other words, instead of simply
averaging/summing node states from source nodes (*source papers*) to the target node (*target paper*),
GAT first applies normalized attention scores to each source node state and then sums.
"""

"""
### (Multi-head) graph attention layer

The GAT model implements multi-head graph attention layers. The `MultiHeadGraphAttention`
layer is simply a concatenation (or averaging) of multiple graph attention layers
(`GraphAttention`), each with separate learnable weights `W`. The `GraphAttention` layer
does the following:

Consider input node states `h^{l}` which are linearly transformed by `W^{l}`, resulting in `z^{l}`.

For each target node:

1. Computes pair-wise attention scores `a^{l}^{T}(z^{l}_{i}||z^{l}_{j})` for all `j`,
resulting in `e_{ij}` (for all `j`).
`||` denotes a concatenation, `_{i}` corresponds to the target node, and `_{j}`
corresponds to a given 1-hop neighbor/source node.
2. Normalizes `e_{ij}` via softmax, so that the sum of incoming edges' attention scores
to the target node (`sum_{k}{e_{norm}_{ik}}`) adds up to 1.
3. Applies attention scores `e_{norm}_{ij}` to `z_{j}`
and adds the results to the new target node state `h^{l+1}_{i}`, for all `j`.
"""


class GraphAttention(layers.Layer):
    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel",
        )
        self.kernel_attention = self.add_weight(
            shape=(self.units * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel_attention",
        )
        self.built = True

    def call(self, inputs):
        node_states, edges = inputs

        # Linearly transform node states
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) Compute pair-wise attention scores
        node_states_expanded = tf.gather(node_states_transformed, edges)
        node_states_expanded = tf.reshape(
            node_states_expanded, (tf.shape(edges)[0], -1)
        )
        attention_scores = tf.nn.leaky_relu(
            tf.matmul(node_states_expanded, self.kernel_attention)
        )
        attention_scores = tf.squeeze(attention_scores, -1)

        # (2) Normalize attention scores
        attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
        attention_scores_sum = tf.math.unsorted_segment_sum(
            data=attention_scores,
            segment_ids=edges[:, 0],
            num_segments=tf.reduce_max(edges[:, 0]) + 1,
        )
        attention_scores_sum = tf.repeat(
            attention_scores_sum, tf.math.bincount(tf.cast(edges[:, 0], "int32"))
        )
        attention_scores_norm = attention_scores / attention_scores_sum

        # (3) Gather node states of neighbors, apply attention scores and aggregate
        node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
        out = tf.math.unsorted_segment_sum(
            data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
            segment_ids=edges[:, 0],
            num_segments=tf.shape(node_states)[0],
        )
        return out


class MultiHeadGraphAttention(layers.Layer):
    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.merge_type = merge_type
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

    def call(self, inputs):
        atom_features, pair_indices = inputs

        # Obtain outputs from each attention head
        outputs = [
            attention_layer([atom_features, pair_indices])
            for attention_layer in self.attention_layers
        ]
        # Concatenate or average the node states from each head
        if self.merge_type == "concat":
            outputs = tf.concat(outputs, axis=-1)
        else:
            outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)
        # Activate and return node states
        return tf.nn.relu(outputs)

"""
### Implement training logic with custom `train_step`, `test_step`, and `predict_step` methods

Notice that the GAT model operates on the entire graph (namely, `node_states` and
`edges`) in all phases (training, validation and testing). Hence, `node_states` and
`edges` are passed to the constructor of the `keras.Model` and used as attributes.
The difference between the phases is the indices (and labels), which are used to gather
certain outputs (`tf.gather(outputs, indices)`).
"""


class GraphAttentionNetwork(keras.Model):
    def __init__(
        self,
        node_states,
        edges,
        hidden_units,
        num_heads,
        num_layers,
        output_dim,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.node_states = node_states
        self.edges = edges
        self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")
        self.attention_layers = [
            MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)
        ]
        self.output_layer = layers.Dense(output_dim)

    def call(self, inputs):
        node_states, edges = inputs
        x = self.preprocess(node_states)
        for attention_layer in self.attention_layers:
            x = attention_layer([x, edges]) + x
        outputs = self.output_layer(x)
        return outputs

    def train_step(self, data):
        indices, labels = data

        with tf.GradientTape() as tape:
            # Forward pass
            outputs = self([self.node_states, self.edges])
            # Compute loss
            loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Compute gradients
        grads = tape.gradient(loss, self.trainable_weights)
        # Apply gradients (update weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        indices = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute probabilities
        return tf.nn.softmax(tf.gather(outputs, indices))

    def test_step(self, data):
        indices, labels = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute loss
        loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}


"""
### Train and evaluate
"""

# Define hyper-parameters
HIDDEN_UNITS = 100
NUM_HEADS = 8
NUM_LAYERS = 3
OUTPUT_DIM = len(class_values)

NUM_EPOCHS = 100
BATCH_SIZE = 256
VALIDATION_SPLIT = 0.1
LEARNING_RATE = 3e-1
MOMENTUM = 0.9

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM)
accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="acc")
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_acc", min_delta=1e-5, patience=5, restore_best_weights=True
)

# Build model
gat_model = GraphAttentionNetwork(
    node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM
)

# Compile model
gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])

gat_model.fit(
    x=train_indices,
    y=train_labels,
    validation_split=VALIDATION_SPLIT,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    callbacks=[early_stopping],
    verbose=2,
)

_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=0)

print("--" * 38 + f"\nTest Accuracy {test_accuracy*100:.1f}%")

"""
### Predict (probabilities)
"""

test_probs = gat_model.predict(x=test_indices)

mapping = {v: k for (k, v) in class_idx.items()}

for i, (probs, label) in enumerate(zip(test_probs[:10], test_labels[:10])):
    print(f"Example {i+1}: {mapping[label]}")
    for j, c in zip(probs, class_idx.keys()):
        print(f"\tProbability of {c: <24} = {j*100:7.3f}%")
    print("---" * 20)

"""
## Conclusions

The results look OK! The GAT model seems to correctly predict the subjects of the papers,
based on what they cite, about 80% of the time. Further improvements could be
made by fine-tuning the hyper-parameters of the GAT. For instance, try changing the number of layers,
the number of hidden units, or the optimizer/learning rate; add regularization (e.g., dropout);
or modify the preprocessing step. We could also try to implement *self-loops*
(i.e., paper X cites paper X) and/or make the graph *undirected*; a small sketch of this
last idea is shown below.
"""