"""
2
Title: Node Classification with Graph Neural Networks
3
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
4
Date created: 2021/05/30
5
Last modified: 2021/05/30
6
Description: Implementing a graph neural network model for predicting the topic of a paper given its citations.
7
Accelerator: GPU
8
"""

"""
## Introduction

Many datasets in various machine learning (ML) applications have structural relationships
between their entities, which can be represented as graphs. Such applications include
social and communication network analysis, traffic prediction, and fraud detection.
[Graph representation learning](https://www.cs.mcgill.ca/~wlh/grl_book/)
aims to build and train models for graph datasets to be used for a variety of ML tasks.

This example demonstrates a simple implementation of a [Graph Neural Network](https://arxiv.org/pdf/1901.00596.pdf)
(GNN) model. The model is used for a node prediction task on the [Cora dataset](https://relational.fit.cvut.cz/dataset/CORA)
to predict the subject of a paper given its words and citation network.

Note that **we implement a Graph Convolution layer from scratch** to provide a better
understanding of how it works. However, there are a number of specialized TensorFlow-based
libraries that provide rich GNN APIs, such as [Spektral](https://graphneural.network/),
[StellarGraph](https://stellargraph.readthedocs.io/en/stable/README.html), and
[GraphNets](https://github.com/deepmind/graph_nets).
"""

"""
## Setup
"""

import os
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

"""
## Prepare the Dataset

The Cora dataset consists of 2,708 scientific papers classified into one of seven classes.
The citation network consists of 5,429 links. Each paper has a binary word vector of size
1,433, indicating the presence of a corresponding word.

### Download the dataset

The dataset has two tab-separated files: `cora.cites` and `cora.content`.

1. The `cora.cites` file includes the citation records with two columns:
`cited_paper_id` (target) and `citing_paper_id` (source).
2. The `cora.content` file includes the paper content records with 1,435 columns:
`paper_id`, `subject`, and 1,433 binary features.

Let's download the dataset.
"""

zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)
data_dir = os.path.join(os.path.dirname(zip_file), "cora")

"""
### Process and visualize the dataset

Then we load the citations data into a Pandas DataFrame.
"""

citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)
print("Citations shape:", citations.shape)

"""
Now we display a sample of the `citations` DataFrame.
The `target` column includes the paper ids cited by the paper ids in the `source` column.
"""

citations.sample(frac=1).head()

"""
Now let's load the papers data into a Pandas DataFrame.
"""

column_names = ["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"]
papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"),
    sep="\t",
    header=None,
    names=column_names,
)
print("Papers shape:", papers.shape)

"""
Now we display a sample of the `papers` DataFrame. The DataFrame includes the `paper_id`
and the `subject` columns, as well as 1,433 binary columns, each representing whether a term exists
in the paper or not.
"""

print(papers.sample(5).T)

"""
Let's display the count of the papers in each subject.
"""

print(papers.subject.value_counts())

"""
We convert the paper ids and the subjects into zero-based indices.
"""

class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])

"""
Now let's visualize the citation graph. Each node in the graph represents a paper,
and the color of the node corresponds to its subject. Note that we only show a sample of
the papers in the dataset.
"""

plt.figure(figsize=(10, 10))
colors = papers["subject"].tolist()
cora_graph = nx.from_pandas_edgelist(citations.sample(n=1500))
subjects = list(papers[papers["paper_id"].isin(list(cora_graph.nodes))]["subject"])
nx.draw_spring(cora_graph, node_size=15, node_color=subjects)


"""
### Split the dataset into stratified train and test sets
"""

train_data, test_data = [], []

for _, group_data in papers.groupby("subject"):
    # Select around 50% of the dataset for training.
    random_selection = np.random.rand(len(group_data.index)) <= 0.5
    train_data.append(group_data[random_selection])
    test_data.append(group_data[~random_selection])

train_data = pd.concat(train_data).sample(frac=1)
test_data = pd.concat(test_data).sample(frac=1)

print("Train data shape:", train_data.shape)
print("Test data shape:", test_data.shape)

"""
## Implement Train and Evaluate Experiment
"""

hidden_units = [32, 32]
learning_rate = 0.01
dropout_rate = 0.5
num_epochs = 300
batch_size = 256

"""
This function compiles and trains an input model using the given training data.
"""


def run_experiment(model, x_train, y_train):
    # Compile the model.
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
    )
    # Create an early stopping callback.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_acc", patience=50, restore_best_weights=True
    )
    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        epochs=num_epochs,
        batch_size=batch_size,
        validation_split=0.15,
        callbacks=[early_stopping],
    )

    return history


"""
This function displays the loss and accuracy curves of the model during training.
"""


def display_learning_curves(history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    ax1.plot(history.history["loss"])
    ax1.plot(history.history["val_loss"])
    ax1.legend(["train", "validation"], loc="upper right")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")

    ax2.plot(history.history["acc"])
    ax2.plot(history.history["val_acc"])
    ax2.legend(["train", "validation"], loc="upper right")
    ax2.set_xlabel("Epochs")
    ax2.set_ylabel("Accuracy")
    plt.show()


"""
## Implement Feedforward Network (FFN) Module

We will use this module in the baseline and the GNN models.
"""


def create_ffn(hidden_units, dropout_rate, name=None):
    ffn_layers = []

    for units in hidden_units:
        ffn_layers.append(layers.BatchNormalization())
        ffn_layers.append(layers.Dropout(dropout_rate))
        ffn_layers.append(layers.Dense(units, activation=tf.nn.gelu))

    return keras.Sequential(ffn_layers, name=name)
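

"""
Regardless of the input dimensionality, the FFN module always produces outputs of size
`hidden_units[-1]`. The following optional snippet (not required for the rest of the example)
illustrates this with a dummy batch of 8-dimensional inputs.
"""

demo_ffn = create_ffn(hidden_units, dropout_rate, name="demo_ffn")
print("Demo FFN output shape:", demo_ffn(tf.zeros((4, 8))).shape)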


"""
## Build a Baseline Neural Network Model

### Prepare the data for the baseline model
"""

feature_names = list(set(papers.columns) - {"paper_id", "subject"})
num_features = len(feature_names)
num_classes = len(class_idx)

# Create train and test features as a numpy array.
x_train = train_data[feature_names].to_numpy()
x_test = test_data[feature_names].to_numpy()
# Create train and test targets.
y_train = train_data["subject"]
y_test = test_data["subject"]

"""
### Implement a baseline classifier

We add five FFN blocks with skip connections, so that we generate a baseline model with
roughly the same number of parameters as the GNN model to be built later.
"""


def create_baseline_model(hidden_units, num_classes, dropout_rate=0.2):
    inputs = layers.Input(shape=(num_features,), name="input_features")
    x = create_ffn(hidden_units, dropout_rate, name="ffn_block1")(inputs)
    for block_idx in range(4):
        # Create an FFN block.
        x1 = create_ffn(hidden_units, dropout_rate, name=f"ffn_block{block_idx + 2}")(x)
        # Add skip connection.
        x = layers.Add(name=f"skip_connection{block_idx + 2}")([x, x1])
    # Compute logits.
    logits = layers.Dense(num_classes, name="logits")(x)
    # Create the model.
    return keras.Model(inputs=inputs, outputs=logits, name="baseline")


baseline_model = create_baseline_model(hidden_units, num_classes, dropout_rate)
baseline_model.summary()

"""
### Train the baseline classifier
"""

history = run_experiment(baseline_model, x_train, y_train)

"""
Let's plot the learning curves.
"""

display_learning_curves(history)

"""
Now we evaluate the baseline model on the test data split.
"""

_, test_accuracy = baseline_model.evaluate(x=x_test, y=y_test, verbose=0)
print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")

"""
### Examine the baseline model predictions

Let's create new data instances by randomly generating binary word vectors with respect to
the word presence probabilities.
"""


def generate_random_instances(num_instances):
    token_probability = x_train.mean(axis=0)
    instances = []
    for _ in range(num_instances):
        probabilities = np.random.uniform(size=len(token_probability))
        instance = (probabilities <= token_probability).astype(int)
        instances.append(instance)

    return np.array(instances)


def display_class_probabilities(probabilities):
    for instance_idx, probs in enumerate(probabilities):
        print(f"Instance {instance_idx + 1}:")
        for class_idx, prob in enumerate(probs):
            print(f"- {class_values[class_idx]}: {round(prob * 100, 2)}%")


"""
Now we show the baseline model predictions given these randomly generated instances.
"""

new_instances = generate_random_instances(num_classes)
logits = baseline_model.predict(new_instances)
probabilities = keras.activations.softmax(tf.convert_to_tensor(logits)).numpy()
display_class_probabilities(probabilities)

"""
## Build a Graph Neural Network Model

### Prepare the data for the graph model

Preparing and loading the graph data into the model for training is the most challenging
part of working with GNN models, and it is addressed in different ways by the specialised libraries.
In this example, we show a simple approach for preparing and using graph data that is suitable
if your dataset consists of a single graph that fits entirely in memory.

The graph data is represented by the `graph_info` tuple, which consists of the following
three elements:

1. `node_features`: This is a `[num_nodes, num_features]` NumPy array that includes the
node features. In this dataset, the nodes are the papers, and the `node_features` are the
word-presence binary vectors of each paper.
2. `edges`: This is a `[2, num_edges]` NumPy array representing a sparse
[adjacency matrix](https://en.wikipedia.org/wiki/Adjacency_matrix#:~:text=In%20graph%20theory%20and%20computer,with%20zeros%20on%20its%20diagonal.)
of the links between the nodes. In this example, the links are the citations between the papers.
3. `edge_weights` (optional): This is a `[num_edges]` NumPy array that includes the edge weights, which *quantify*
the relationships between nodes in the graph. In this example, there are no weights for the paper citations.
"""

# Create an edges array (sparse adjacency matrix) of shape [2, num_edges].
edges = citations[["source", "target"]].to_numpy().T
# Create an edge weights array of ones.
edge_weights = tf.ones(shape=edges.shape[1])
# Create a node features array of shape [num_nodes, num_features].
node_features = tf.cast(
    papers.sort_values("paper_id")[feature_names].to_numpy(), dtype=tf.dtypes.float32
)
# Create graph info tuple with node_features, edges, and edge_weights.
graph_info = (node_features, edges, edge_weights)

print("Edges shape:", edges.shape)
print("Nodes shape:", node_features.shape)

"""
### Implement a graph convolution layer

We implement a graph convolution module as a [Keras Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer?version=nightly).
Our `GraphConvLayer` performs the following steps:

1. **Prepare**: The input node representations are processed using an FFN to produce a *message*. You can simplify
the processing by applying only a linear transformation to the representations.
2. **Aggregate**: The messages of the neighbours of each node are aggregated with
respect to the `edge_weights` using a *permutation invariant* pooling operation, such as *sum*, *mean*, or *max*,
to prepare a single aggregated message for each node. See, for example, [tf.math.unsorted_segment_sum](https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum)
and the related APIs used to aggregate neighbour messages.
3. **Update**: The `node_repesentations` and `aggregated_messages` (both of shape `[num_nodes, representation_dim]`)
are combined and processed to produce the new state of the node representations (node embeddings).
If `combination_type` is `gru`, the `node_repesentations` and `aggregated_messages` are stacked to create a sequence,
then processed by a GRU layer. Otherwise, the `node_repesentations` and `aggregated_messages` are added
or concatenated, then processed using an FFN.

The implemented technique uses ideas from [Graph Convolutional Networks](https://arxiv.org/abs/1609.02907),
[GraphSage](https://arxiv.org/abs/1706.02216), [Graph Isomorphism Network](https://arxiv.org/abs/1810.00826),
[Simplifying Graph Convolutional Networks](https://arxiv.org/abs/1902.07153), and
[Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493).
Two other key techniques that are not covered are [Graph Attention Networks](https://arxiv.org/abs/1710.10903)
and [Message Passing Neural Networks](https://arxiv.org/abs/1704.01212).
"""


def create_gru(hidden_units, dropout_rate):
    inputs = keras.layers.Input(shape=(2, hidden_units[0]))
    x = inputs
    for units in hidden_units:
        x = layers.GRU(
            units=units,
            activation="tanh",
            recurrent_activation="sigmoid",
            return_sequences=True,
            dropout=dropout_rate,
            return_state=False,
            recurrent_dropout=dropout_rate,
        )(x)
    return keras.Model(inputs=inputs, outputs=x)


class GraphConvLayer(layers.Layer):
    def __init__(
        self,
        hidden_units,
        dropout_rate=0.2,
        aggregation_type="mean",
        combination_type="concat",
        normalize=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.aggregation_type = aggregation_type
        self.combination_type = combination_type
        self.normalize = normalize

        self.ffn_prepare = create_ffn(hidden_units, dropout_rate)
        if self.combination_type == "gru":
            self.update_fn = create_gru(hidden_units, dropout_rate)
        else:
            self.update_fn = create_ffn(hidden_units, dropout_rate)

    def prepare(self, node_repesentations, weights=None):
        # node_repesentations shape is [num_edges, embedding_dim].
        messages = self.ffn_prepare(node_repesentations)
        if weights is not None:
            messages = messages * tf.expand_dims(weights, -1)
        return messages

    def aggregate(self, node_indices, neighbour_messages, node_repesentations):
        # node_indices shape is [num_edges].
        # neighbour_messages shape: [num_edges, representation_dim].
        # node_repesentations shape is [num_nodes, representation_dim].
        num_nodes = node_repesentations.shape[0]
        if self.aggregation_type == "sum":
            aggregated_message = tf.math.unsorted_segment_sum(
                neighbour_messages, node_indices, num_segments=num_nodes
            )
        elif self.aggregation_type == "mean":
            aggregated_message = tf.math.unsorted_segment_mean(
                neighbour_messages, node_indices, num_segments=num_nodes
            )
        elif self.aggregation_type == "max":
            aggregated_message = tf.math.unsorted_segment_max(
                neighbour_messages, node_indices, num_segments=num_nodes
            )
        else:
            raise ValueError(f"Invalid aggregation type: {self.aggregation_type}.")

        return aggregated_message

    def update(self, node_repesentations, aggregated_messages):
        # node_repesentations shape is [num_nodes, representation_dim].
        # aggregated_messages shape is [num_nodes, representation_dim].
        if self.combination_type == "gru":
            # Create a sequence of two elements for the GRU layer.
            h = tf.stack([node_repesentations, aggregated_messages], axis=1)
        elif self.combination_type == "concat":
            # Concatenate the node_repesentations and aggregated_messages.
            h = tf.concat([node_repesentations, aggregated_messages], axis=1)
        elif self.combination_type == "add":
            # Add node_repesentations and aggregated_messages.
            h = node_repesentations + aggregated_messages
        else:
            raise ValueError(f"Invalid combination type: {self.combination_type}.")

        # Apply the processing function.
        node_embeddings = self.update_fn(h)
        if self.combination_type == "gru":
            node_embeddings = tf.unstack(node_embeddings, axis=1)[-1]

        if self.normalize:
            node_embeddings = tf.nn.l2_normalize(node_embeddings, axis=-1)
        return node_embeddings

    def call(self, inputs):
        """Process the inputs to produce the node_embeddings.

        inputs: a tuple of three elements: node_repesentations, edges, edge_weights.
        Returns: node_embeddings of shape [num_nodes, representation_dim].
        """

        node_repesentations, edges, edge_weights = inputs
        # Get node_indices (source) and neighbour_indices (target) from edges.
        node_indices, neighbour_indices = edges[0], edges[1]
        # neighbour_repesentations shape is [num_edges, representation_dim].
        neighbour_repesentations = tf.gather(node_repesentations, neighbour_indices)

        # Prepare the messages of the neighbours.
        neighbour_messages = self.prepare(neighbour_repesentations, edge_weights)
        # Aggregate the neighbour messages.
        aggregated_messages = self.aggregate(
            node_indices, neighbour_messages, node_repesentations
        )
        # Update the node embedding with the neighbour messages.
        return self.update(node_repesentations, aggregated_messages)
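

"""
As a quick, optional sanity check (not required for the rest of the example), we can apply a
freshly initialized `GraphConvLayer` to the Cora graph prepared earlier and verify that it
produces one embedding of size `hidden_units[-1]` per node.
"""

demo_graph_conv = GraphConvLayer(hidden_units, dropout_rate, name="demo_graph_conv")
demo_embeddings = demo_graph_conv((node_features, edges, edge_weights))
print("Demo node embeddings shape:", demo_embeddings.shape)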


"""
### Implement a graph neural network node classifier

The GNN classification model follows the [Design Space for Graph Neural Networks](https://arxiv.org/abs/2011.08843) approach,
as follows:

1. Apply preprocessing using an FFN to the node features to generate initial node representations.
2. Apply one or more graph convolutional layers, with skip connections, to the node representations
to produce node embeddings.
3. Apply post-processing using an FFN to the node embeddings to generate the final node embeddings.
4. Feed the node embeddings into a Softmax layer to predict the node class.

Each graph convolutional layer added captures information from a further level of neighbours.
However, adding many graph convolutional layers can cause oversmoothing, where the model
produces similar embeddings for all the nodes.

Note that `graph_info` is passed to the constructor of the Keras model and used as a *property*
of the Keras model object, rather than as input data for training or prediction.
The model will accept a **batch** of `node_indices`, which are used to look up the
node features and neighbours from the `graph_info`.
"""


class GNNNodeClassifier(tf.keras.Model):
    def __init__(
        self,
        graph_info,
        num_classes,
        hidden_units,
        aggregation_type="sum",
        combination_type="concat",
        dropout_rate=0.2,
        normalize=True,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # Unpack graph_info to three elements: node_features, edges, and edge_weight.
        node_features, edges, edge_weights = graph_info
        self.node_features = node_features
        self.edges = edges
        self.edge_weights = edge_weights
        # Set edge_weights to ones if not provided.
        if self.edge_weights is None:
            self.edge_weights = tf.ones(shape=edges.shape[1])
        # Scale edge_weights to sum to 1.
        self.edge_weights = self.edge_weights / tf.math.reduce_sum(self.edge_weights)

        # Create a process layer.
        self.preprocess = create_ffn(hidden_units, dropout_rate, name="preprocess")
        # Create the first GraphConv layer.
        self.conv1 = GraphConvLayer(
            hidden_units,
            dropout_rate,
            aggregation_type,
            combination_type,
            normalize,
            name="graph_conv1",
        )
        # Create the second GraphConv layer.
        self.conv2 = GraphConvLayer(
            hidden_units,
            dropout_rate,
            aggregation_type,
            combination_type,
            normalize,
            name="graph_conv2",
        )
        # Create a postprocess layer.
        self.postprocess = create_ffn(hidden_units, dropout_rate, name="postprocess")
        # Create a compute logits layer.
        self.compute_logits = layers.Dense(units=num_classes, name="logits")

    def call(self, input_node_indices):
        # Preprocess the node_features to produce node representations.
        x = self.preprocess(self.node_features)
        # Apply the first graph conv layer.
        x1 = self.conv1((x, self.edges, self.edge_weights))
        # Skip connection.
        x = x1 + x
        # Apply the second graph conv layer.
        x2 = self.conv2((x, self.edges, self.edge_weights))
        # Skip connection.
        x = x2 + x
        # Postprocess node embedding.
        x = self.postprocess(x)
        # Fetch node embeddings for the input node_indices.
        node_embeddings = tf.gather(x, input_node_indices)
        # Compute logits.
        return self.compute_logits(node_embeddings)


"""
Let's test instantiating and calling the GNN model.
Notice that if you provide `N` node indices, the output will be a tensor of shape `[N, num_classes]`,
regardless of the size of the graph.
"""

gnn_model = GNNNodeClassifier(
    graph_info=graph_info,
    num_classes=num_classes,
    hidden_units=hidden_units,
    dropout_rate=dropout_rate,
    name="gnn_model",
)

print("GNN output shape:", gnn_model([1, 10, 100]).shape)

gnn_model.summary()
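
"""
Recall that the baseline FFN model was sized to have roughly the same number of parameters as
the GNN model. As an optional check, we can compare the two with `count_params()`.
"""

print("Baseline model parameters:", baseline_model.count_params())
print("GNN model parameters:", gnn_model.count_params())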

"""
### Train the GNN model

Note that we use the standard *supervised* cross-entropy loss to train the model.
However, we can add another *self-supervised* loss term for the generated node embeddings
that makes sure that neighbouring nodes in the graph have similar representations, while faraway
nodes have dissimilar representations.
"""

x_train = train_data.paper_id.to_numpy()
history = run_experiment(gnn_model, x_train, y_train)

"""
Let's plot the learning curves.
"""

display_learning_curves(history)

"""
Now we evaluate the GNN model on the test data split.
The results may vary depending on the training sample; however, the GNN model consistently outperforms
the baseline model in terms of test accuracy.
"""

x_test = test_data.paper_id.to_numpy()
_, test_accuracy = gnn_model.evaluate(x=x_test, y=y_test, verbose=0)
print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")

"""
### Examine the GNN model predictions

Let's add the new instances as nodes to the `node_features`, and generate links
(citations) to existing nodes.
"""

# First we add the N new_instances as nodes to the graph
# by appending the new_instances to node_features.
num_nodes = node_features.shape[0]
new_node_features = np.concatenate([node_features, new_instances])
# Second we add the M edges (citations) from each new node to a set
# of existing nodes in a particular subject.
new_node_indices = [i + num_nodes for i in range(num_classes)]
new_citations = []
for subject_idx, group in papers.groupby("subject"):
    subject_papers = list(group.paper_id)
    # Select 5 random papers from the same subject.
    selected_paper_indices1 = np.random.choice(subject_papers, 5)
    # Select 2 random papers from any subject.
    selected_paper_indices2 = np.random.choice(list(papers.paper_id), 2)
    # Merge the selected paper indices.
    selected_paper_indices = np.concatenate(
        [selected_paper_indices1, selected_paper_indices2], axis=0
    )
    # Create edges between the citing paper idx and the selected cited papers.
    citing_paper_idx = new_node_indices[subject_idx]
    for cited_paper_idx in selected_paper_indices:
        new_citations.append([citing_paper_idx, cited_paper_idx])

new_citations = np.array(new_citations).T
new_edges = np.concatenate([edges, new_citations], axis=1)

"""
Now let's update the `node_features` and the `edges` in the GNN model.
"""

print("Original node_features shape:", gnn_model.node_features.shape)
print("Original edges shape:", gnn_model.edges.shape)
gnn_model.node_features = new_node_features
gnn_model.edges = new_edges
gnn_model.edge_weights = tf.ones(shape=new_edges.shape[1])
print("New node_features shape:", gnn_model.node_features.shape)
print("New edges shape:", gnn_model.edges.shape)

logits = gnn_model.predict(tf.convert_to_tensor(new_node_indices))
probabilities = keras.activations.softmax(tf.convert_to_tensor(logits)).numpy()
display_class_probabilities(probabilities)

"""
Notice that the probabilities of the expected subjects
(to which several citations are added) are higher than those produced by the baseline model.
"""