"""
Title: Message-passing neural network (MPNN) for molecular property prediction
Author: [akensert](http://github.com/akensert)
Date created: 2021/08/16
Last modified: 2021/12/27
Description: Implementation of an MPNN to predict blood-brain barrier permeability.
Accelerator: GPU
"""

"""
## Introduction

In this tutorial, we will implement a type of graph neural network (GNN) known as
_message passing neural network_ (MPNN) to predict graph properties. Specifically, we will
implement an MPNN to predict a molecular property known as
_blood-brain barrier permeability_ (BBBP).

Motivation: as molecules are naturally represented as an undirected graph `G = (V, E)`,
where `V` is a set of vertices (nodes; atoms) and `E` a set of edges (bonds), GNNs (such
as MPNNs) have proven to be a useful method for predicting molecular properties.

Until now, more traditional methods, such as random forests and support vector machines,
have been commonly used to predict molecular properties. In contrast to GNNs, these
traditional approaches often operate on precomputed molecular features such as
molecular weight, polarity, charge, number of carbon atoms, etc. Although these
molecular features prove to be good predictors for various molecular properties, it is
hypothesized that operating on these more "raw", "low-level", features could prove even
better.

### References

In recent years, a lot of effort has been put into developing neural networks for
graph data, including molecular graphs. For a summary of graph neural networks, see e.g.,
[A Comprehensive Survey on Graph Neural Networks](https://arxiv.org/abs/1901.00596) and
[Graph Neural Networks: A Review of Methods and Applications](https://arxiv.org/abs/1812.08434);
and for further reading on the specific
graph neural network implemented in this tutorial see
[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) and
[DeepChem's MPNNModel](https://deepchem.readthedocs.io/en/latest/api_reference/models.html#mpnnmodel).
"""

"""
## Setup

### Install RDKit and other dependencies

(Text below taken from
[this tutorial](https://keras.io/examples/generative/wgan-graphs/)).

[RDKit](https://www.rdkit.org/) is a collection of cheminformatics and machine-learning
software written in C++ and Python. In this tutorial, RDKit is used to conveniently and
efficiently transform
[SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) to
molecule objects, and then from those obtain sets of atoms and bonds.

SMILES expresses the structure of a given molecule in the form of an ASCII string.
The SMILES string is a compact encoding which, for smaller molecules, is relatively
human-readable. Encoding molecules as strings facilitates database and/or web searching
of a given molecule. RDKit uses algorithms to
accurately transform a given SMILES to a molecule object, which can then
be used to compute a great number of molecular properties/features.
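
For instance (a minimal sketch, runnable once RDKit is installed as described below;
`"CCO"` is the SMILES string for ethanol):

```
from rdkit import Chem

molecule = Chem.MolFromSmiles("CCO")  # parse the SMILES string into a molecule object
print(molecule.GetNumAtoms())         # 3 heavy atoms: two carbons, one oxygen
```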

Note that RDKit is commonly installed via [Conda](https://www.rdkit.org/docs/Install.html).
However, thanks to
[rdkit_platform_wheels](https://github.com/kuelumbus/rdkit_platform_wheels), rdkit
can now (for the sake of this tutorial) be installed easily via pip, as follows:

```
pip -q install rdkit-pypi
```

For easy and efficient reading of CSV files, and for visualization, the following need
to be installed:

```
pip -q install pandas
pip -q install Pillow
pip -q install matplotlib
pip -q install pydot
sudo apt-get -qq install graphviz
```
"""

"""
### Import packages
"""

import os

# Temporarily suppress TensorFlow logs
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw import MolsToGridImage

# Temporarily suppress warnings and RDKit logs
warnings.filterwarnings("ignore")
RDLogger.DisableLog("rdApp.*")

np.random.seed(42)
tf.random.set_seed(42)

"""
## Dataset

Information about the dataset can be found in
[A Bayesian Approach to in Silico Blood-Brain Barrier Penetration Modeling](https://pubs.acs.org/doi/10.1021/ci300124c)
and [MoleculeNet: A Benchmark for Molecular Machine Learning](https://arxiv.org/abs/1703.00564).
The dataset will be downloaded from [MoleculeNet.org](https://moleculenet.org/datasets-1).

### About

The dataset contains **2,050** molecules. Each molecule comes with a **name**, a **label**
and a **SMILES** string.

The blood-brain barrier (BBB) is a membrane separating the blood from the brain
extracellular fluid, hence blocking out most drugs (molecules) from reaching
the brain. Because of this, BBBP has been an important property to study for the
development of new drugs that aim to target the central nervous system. The labels for
this dataset are binary (1 or 0) and indicate the permeability of the molecules.
"""

csv_path = keras.utils.get_file(
    "BBBP.csv", "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv"
)

df = pd.read_csv(csv_path, usecols=[1, 2, 3])
df.iloc[96:104]

"""
### Define features

To encode features for atoms and bonds (which we will need later),
we'll define two classes: `AtomFeaturizer` and `BondFeaturizer` respectively.

To reduce the lines of code, i.e., to keep this tutorial short and concise,
only a handful of (atom and bond) features will be considered: \[atom features\]
[symbol (element)](https://en.wikipedia.org/wiki/Chemical_element),
[number of valence electrons](https://en.wikipedia.org/wiki/Valence_electron),
[number of bonded hydrogens](https://en.wikipedia.org/wiki/Hydrogen),
[orbital hybridization](https://en.wikipedia.org/wiki/Orbital_hybridisation),
\[bond features\]
[(covalent) bond type](https://en.wikipedia.org/wiki/Covalent_bond), and
[conjugation](https://en.wikipedia.org/wiki/Conjugated_system).
"""


class Featurizer:
    def __init__(self, allowable_sets):
        # Assign each allowable feature value a unique index into a
        # multi-hot feature vector of length `self.dim`.
        self.dim = 0
        self.features_mapping = {}
        for k, s in allowable_sets.items():
            s = sorted(list(s))
            self.features_mapping[k] = dict(zip(s, range(self.dim, len(s) + self.dim)))
            self.dim += len(s)

    def encode(self, inputs):
        # Build the multi-hot vector: set a 1.0 for each feature value
        # found on `inputs`; unknown values are simply skipped.
        output = np.zeros((self.dim,))
        for name_feature, feature_mapping in self.features_mapping.items():
            feature = getattr(self, name_feature)(inputs)
            if feature not in feature_mapping:
                continue
            output[feature_mapping[feature]] = 1.0
        return output


class AtomFeaturizer(Featurizer):
    def __init__(self, allowable_sets):
        super().__init__(allowable_sets)

    def symbol(self, atom):
        return atom.GetSymbol()

    def n_valence(self, atom):
        return atom.GetTotalValence()

    def n_hydrogens(self, atom):
        return atom.GetTotalNumHs()

    def hybridization(self, atom):
        return atom.GetHybridization().name.lower()


class BondFeaturizer(Featurizer):
    def __init__(self, allowable_sets):
        # Reserve one extra slot to flag "no bond" (used for self-loops).
        super().__init__(allowable_sets)
        self.dim += 1

    def encode(self, bond):
        output = np.zeros((self.dim,))
        if bond is None:
            output[-1] = 1.0
            return output
        output = super().encode(bond)
        return output

    def bond_type(self, bond):
        return bond.GetBondType().name.lower()

    def conjugated(self, bond):
        return bond.GetIsConjugated()


atom_featurizer = AtomFeaturizer(
    allowable_sets={
        "symbol": {"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"},
        "n_valence": {0, 1, 2, 3, 4, 5, 6},
        "n_hydrogens": {0, 1, 2, 3, 4},
        "hybridization": {"s", "sp", "sp2", "sp3"},
    }
)

bond_featurizer = BondFeaturizer(
    allowable_sets={
        "bond_type": {"single", "double", "triple", "aromatic"},
        "conjugated": {True, False},
    }
)
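
# A quick sanity check (a sketch; ethanol, "CCO", is just an illustrative input):
# the first atom is an sp3 carbon with 4 valence electrons and 3 hydrogens, so its
# encoding is a multi-hot vector of length atom_featurizer.dim = 13 + 7 + 5 + 4 = 29.
example_atom = Chem.MolFromSmiles("CCO").GetAtomWithIdx(0)
print(atom_featurizer.encode(example_atom))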


"""
### Generate graphs

Before we can generate complete graphs from SMILES, we need to implement the following functions:

1. `molecule_from_smiles`, which takes as input a SMILES and returns a molecule object.
This is all handled by RDKit.

2. `graph_from_molecule`, which takes as input a molecule object and returns a graph,
represented as a three-tuple (atom_features, bond_features, pair_indices). For this we
will make use of the classes defined previously.

Finally, we can now implement the function `graphs_from_smiles`, which applies function (1)
and subsequently (2) on all SMILES of the training, validation and test datasets.

Note: although scaffold splitting is recommended for this dataset (see
[here](https://arxiv.org/abs/1703.00564)), for simplicity, a simple random split is
performed instead.
"""


def molecule_from_smiles(smiles):
    # MolFromSmiles(m, sanitize=True) should be equivalent to
    # MolFromSmiles(m, sanitize=False) -> SanitizeMol(m) -> AssignStereochemistry(m, ...)
    molecule = Chem.MolFromSmiles(smiles, sanitize=False)

    # If sanitization is unsuccessful, catch the error, and try again without
    # the sanitization step that caused the error
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)

    Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
    return molecule


def graph_from_molecule(molecule):
    # Initialize graph
    atom_features = []
    bond_features = []
    pair_indices = []

    for atom in molecule.GetAtoms():
        atom_features.append(atom_featurizer.encode(atom))

        # Add self-loops
        pair_indices.append([atom.GetIdx(), atom.GetIdx()])
        bond_features.append(bond_featurizer.encode(None))

        for neighbor in atom.GetNeighbors():
            bond = molecule.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx())
            pair_indices.append([atom.GetIdx(), neighbor.GetIdx()])
            bond_features.append(bond_featurizer.encode(bond))

    return np.array(atom_features), np.array(bond_features), np.array(pair_indices)


def graphs_from_smiles(smiles_list):
    # Initialize graphs
    atom_features_list = []
    bond_features_list = []
    pair_indices_list = []

    for smiles in smiles_list:
        molecule = molecule_from_smiles(smiles)
        atom_features, bond_features, pair_indices = graph_from_molecule(molecule)

        atom_features_list.append(atom_features)
        bond_features_list.append(bond_features)
        pair_indices_list.append(pair_indices)

    # Convert lists to ragged tensors for tf.data.Dataset later on
    return (
        tf.ragged.constant(atom_features_list, dtype=tf.float32),
        tf.ragged.constant(bond_features_list, dtype=tf.float32),
        tf.ragged.constant(pair_indices_list, dtype=tf.int64),
    )


# Shuffle array of indices ranging from 0 to 2049
permuted_indices = np.random.permutation(np.arange(df.shape[0]))

# Train set: 80% of data
train_index = permuted_indices[: int(df.shape[0] * 0.8)]
x_train = graphs_from_smiles(df.iloc[train_index].smiles)
y_train = df.iloc[train_index].p_np

# Valid set: 19% of data
valid_index = permuted_indices[int(df.shape[0] * 0.8) : int(df.shape[0] * 0.99)]
x_valid = graphs_from_smiles(df.iloc[valid_index].smiles)
y_valid = df.iloc[valid_index].p_np

# Test set: 1% of data
test_index = permuted_indices[int(df.shape[0] * 0.99) :]
x_test = graphs_from_smiles(df.iloc[test_index].smiles)
y_test = df.iloc[test_index].p_np

"""
### Test the functions
"""

print(f"Name:\t{df.name[100]}\nSMILES:\t{df.smiles[100]}\nBBBP:\t{df.p_np[100]}")
molecule = molecule_from_smiles(df.iloc[100].smiles)
print("Molecule:")
molecule

"""
"""

graph = graph_from_molecule(molecule)
print("Graph (including self-loops):")
print("\tatom features\t", graph[0].shape)
print("\tbond features\t", graph[1].shape)
print("\tpair indices\t", graph[2].shape)


"""
### Create a `tf.data.Dataset`

In this tutorial, the MPNN implementation will take as input (per iteration) a single graph.
Therefore, given a batch of (sub)graphs (molecules), we need to merge them into a
single graph (we'll refer to this graph as the *global graph*).
This global graph is a disconnected graph where each subgraph is
completely separated from the other subgraphs.
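
As a concrete example: if the first molecule of a batch has 3 atoms and the second has
2 atoms, the atom indices of the second molecule are offset by 3 during merging, so a
bond `[0, 1]` within the second molecule becomes `[3, 4]` in the global graph. This is
what the `increment` logic in `prepare_batch` below computes.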
"""


def prepare_batch(x_batch, y_batch):
    """Merges (sub)graphs of batch into a single global (disconnected) graph"""

    atom_features, bond_features, pair_indices = x_batch

    # Obtain number of atoms and bonds for each graph (molecule)
    num_atoms = atom_features.row_lengths()
    num_bonds = bond_features.row_lengths()

    # Obtain partition indices (molecule_indicator), which will be used to
    # gather (sub)graphs from global graph in model later on
    molecule_indices = tf.range(len(num_atoms))
    molecule_indicator = tf.repeat(molecule_indices, num_atoms)

    # Merge (sub)graphs into a global (disconnected) graph. Adding 'increment' to
    # 'pair_indices' (and merging ragged tensors) produces the global graph
    gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:])
    increment = tf.cumsum(num_atoms[:-1])
    increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)])
    pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    pair_indices = pair_indices + increment[:, tf.newaxis]
    atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()

    return (atom_features, bond_features, pair_indices, molecule_indicator), y_batch


def MPNNDataset(X, y, batch_size=32, shuffle=False):
    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    if shuffle:
        dataset = dataset.shuffle(1024)
    return dataset.batch(batch_size).map(prepare_batch, -1).prefetch(-1)
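
# A quick shape check (a sketch, not part of the original pipeline): pull one batch
# and inspect the merged global graph. The exact sizes depend on the sampled molecules.
(example_x, example_y) = next(iter(MPNNDataset(x_train, y_train)))
atom_feats, bond_feats, pairs, mol_ind = example_x
print(atom_feats.shape, bond_feats.shape, pairs.shape, mol_ind.shape)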


"""
## Model

The MPNN model can take on various shapes and forms. In this tutorial, we will implement an
MPNN based on the original paper
[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) and
[DeepChem's MPNNModel](https://deepchem.readthedocs.io/en/latest/api_reference/models.html#mpnnmodel).
The MPNN of this tutorial consists of three stages: message passing, readout and
classification.


### Message passing

The message passing step itself consists of two parts:

1. The *edge network*, which passes messages from the 1-hop neighbors `w_{i}` of `v`
to `v`, based on the edge features between them (`e_{vw_{i}}`),
resulting in an updated node (state) `v'`. `w_{i}` denotes the `i:th` neighbor of
`v`.

2. The *gated recurrent unit* (GRU), which takes as input the most recent node state
and updates it based on previous node states. In
other words, the most recent node state serves as the input to the GRU, while the previous
node states are incorporated within the memory state of the GRU. This allows information
to travel from one node state (e.g., `v`) to another (e.g., `v''`).

Importantly, steps (1) and (2) are repeated for `k` steps, where at each step `1...k`
the radius (or number of hops) of aggregated information from `v` increases by 1.
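
Schematically, one message passing step can be sketched as follows (using this
tutorial's notation, where `A(e_vw)` is the edge network's learned transformation
of the edge features between `v` and `w`):

```
m_v  = sum_{w in N(v)} A(e_vw) @ h_w    # edge network: aggregate messages
h_v' = GRU(m_v, h_v)                    # GRU: update node state
```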
"""


class EdgeNetwork(layers.Layer):
    def build(self, input_shape):
        self.atom_dim = input_shape[0][-1]
        self.bond_dim = input_shape[1][-1]
        self.kernel = self.add_weight(
            shape=(self.bond_dim, self.atom_dim * self.atom_dim),
            initializer="glorot_uniform",
            name="kernel",
        )
        self.bias = self.add_weight(
            shape=(self.atom_dim * self.atom_dim,),
            initializer="zeros",
            name="bias",
        )
        self.built = True

    def call(self, inputs):
        atom_features, bond_features, pair_indices = inputs

        # Apply linear transformation to bond features
        bond_features = tf.matmul(bond_features, self.kernel) + self.bias

        # Reshape for neighborhood aggregation later
        bond_features = tf.reshape(bond_features, (-1, self.atom_dim, self.atom_dim))

        # Obtain atom features of neighbors
        atom_features_neighbors = tf.gather(atom_features, pair_indices[:, 1])
        atom_features_neighbors = tf.expand_dims(atom_features_neighbors, axis=-1)

        # Apply neighborhood aggregation
        transformed_features = tf.matmul(bond_features, atom_features_neighbors)
        transformed_features = tf.squeeze(transformed_features, axis=-1)
        aggregated_features = tf.math.unsorted_segment_sum(
            transformed_features,
            pair_indices[:, 0],
            num_segments=tf.shape(atom_features)[0],
        )
        return aggregated_features


class MessagePassing(layers.Layer):
    def __init__(self, units, steps=4, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.steps = steps

    def build(self, input_shape):
        self.atom_dim = input_shape[0][-1]
        self.message_step = EdgeNetwork()
        self.pad_length = max(0, self.units - self.atom_dim)
        self.update_step = layers.GRUCell(self.atom_dim + self.pad_length)
        self.built = True

    def call(self, inputs):
        atom_features, bond_features, pair_indices = inputs

        # Pad atom features if number of desired units exceeds atom_features dim.
        # Alternatively, a dense layer could be used here.
        atom_features_updated = tf.pad(atom_features, [(0, 0), (0, self.pad_length)])

        # Perform a number of steps of message passing
        for i in range(self.steps):
            # Aggregate information from neighbors
            atom_features_aggregated = self.message_step(
                [atom_features_updated, bond_features, pair_indices]
            )

            # Update node state via a step of GRU
            atom_features_updated, _ = self.update_step(
                atom_features_aggregated, atom_features_updated
            )
        return atom_features_updated


"""
### Readout

When the message passing procedure ends, the k-step-aggregated node states are to be partitioned
into subgraphs (corresponding to each molecule in the batch) and subsequently
reduced to graph-level embeddings. In the
[original paper](https://arxiv.org/abs/1704.01212), a
[set-to-set layer](https://arxiv.org/abs/1511.06391) was used for this purpose.
In this tutorial however, a transformer encoder + average pooling will be used. Specifically:

* the k-step-aggregated node states will be partitioned into the subgraphs
(corresponding to each molecule in the batch);
* each subgraph will then be padded to match the subgraph with the greatest number of nodes, followed
by a `tf.stack(...)`;
* the (stacked padded) tensor, encoding subgraphs (each subgraph containing a set of node states), is
masked to make sure the paddings don't interfere with training;
* finally, the tensor is passed to the transformer followed by average pooling.
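
For example, with a batch of 32 molecules whose largest molecule has 24 atoms, and with
`message_units=64`, the stacked padded tensor has shape `(32, 24, 64)`; the transformer
encoder preserves this shape, and the average pooling then reduces it to `(32, 64)`
graph-level embeddings.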
"""


class PartitionPadding(layers.Layer):
    def __init__(self, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def call(self, inputs):
        atom_features, molecule_indicator = inputs

        # Obtain subgraphs
        atom_features_partitioned = tf.dynamic_partition(
            atom_features, molecule_indicator, self.batch_size
        )

        # Pad and stack subgraphs
        num_atoms = [tf.shape(f)[0] for f in atom_features_partitioned]
        max_num_atoms = tf.reduce_max(num_atoms)
        atom_features_stacked = tf.stack(
            [
                tf.pad(f, [(0, max_num_atoms - n), (0, 0)])
                for f, n in zip(atom_features_partitioned, num_atoms)
            ],
            axis=0,
        )

        # Remove empty subgraphs (usually for last batch in dataset)
        gather_indices = tf.where(tf.reduce_sum(atom_features_stacked, (1, 2)) != 0)
        gather_indices = tf.squeeze(gather_indices, axis=-1)
        return tf.gather(atom_features_stacked, gather_indices, axis=0)


class TransformerEncoderReadout(layers.Layer):
    def __init__(
        self, num_heads=8, embed_dim=64, dense_dim=512, batch_size=32, **kwargs
    ):
        super().__init__(**kwargs)

        self.partition_padding = PartitionPadding(batch_size)
        self.attention = layers.MultiHeadAttention(num_heads, embed_dim)
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.average_pooling = layers.GlobalAveragePooling1D()

    def call(self, inputs):
        x = self.partition_padding(inputs)
        padding_mask = tf.reduce_any(tf.not_equal(x, 0.0), axis=-1)
        padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]
        attention_output = self.attention(x, x, attention_mask=padding_mask)
        proj_input = self.layernorm_1(x + attention_output)
        proj_output = self.layernorm_2(proj_input + self.dense_proj(proj_input))
        return self.average_pooling(proj_output)


"""
### Message Passing Neural Network (MPNN)

It is now time to complete the MPNN model. In addition to the message passing
and readout, a two-layer classification network will be implemented to make
predictions of BBBP.
"""


def MPNNModel(
    atom_dim,
    bond_dim,
    batch_size=32,
    message_units=64,
    message_steps=4,
    num_attention_heads=8,
    dense_units=512,
):
    atom_features = layers.Input((atom_dim,), dtype="float32", name="atom_features")
    bond_features = layers.Input((bond_dim,), dtype="float32", name="bond_features")
    pair_indices = layers.Input((2,), dtype="int32", name="pair_indices")
    molecule_indicator = layers.Input((), dtype="int32", name="molecule_indicator")

    x = MessagePassing(message_units, message_steps)(
        [atom_features, bond_features, pair_indices]
    )

    x = TransformerEncoderReadout(
        num_attention_heads, message_units, dense_units, batch_size
    )([x, molecule_indicator])

    x = layers.Dense(dense_units, activation="relu")(x)
    x = layers.Dense(1, activation="sigmoid")(x)

    model = keras.Model(
        inputs=[atom_features, bond_features, pair_indices, molecule_indicator],
        outputs=[x],
    )
    return model


mpnn = MPNNModel(
    atom_dim=x_train[0][0][0].shape[0],
    bond_dim=x_train[1][0][0].shape[0],
)

mpnn.compile(
    loss=keras.losses.BinaryCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=5e-4),
    metrics=[keras.metrics.AUC(name="AUC")],
)

keras.utils.plot_model(mpnn, show_dtype=True, show_shapes=True)

"""
### Training

Note that the `class_weight` passed to `fit` below weights the less frequent
non-permeable class (0) higher, counteracting the class imbalance of the dataset.
"""

train_dataset = MPNNDataset(x_train, y_train)
valid_dataset = MPNNDataset(x_valid, y_valid)
test_dataset = MPNNDataset(x_test, y_test)

history = mpnn.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=40,
    verbose=2,
    class_weight={0: 2.0, 1: 0.5},
)

plt.figure(figsize=(10, 6))
plt.plot(history.history["AUC"], label="train AUC")
plt.plot(history.history["val_AUC"], label="valid AUC")
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("AUC", fontsize=16)
plt.legend(fontsize=16)

"""
### Predicting
"""

molecules = [molecule_from_smiles(df.smiles.values[index]) for index in test_index]
y_true = [df.p_np.values[index] for index in test_index]
y_pred = tf.squeeze(mpnn.predict(test_dataset), axis=1)

legends = [f"y_true/y_pred = {y_true[i]}/{y_pred[i]:.2f}" for i in range(len(y_true))]
MolsToGridImage(molecules, molsPerRow=4, legends=legends)

"""
## Conclusions

In this tutorial, we demonstrated a message passing neural network (MPNN) to
predict blood-brain barrier permeability (BBBP) for a number of different molecules. We
first had to construct graphs from SMILES, then build a Keras model that could
operate on these graphs, and finally train the model to make the predictions.

Example available on HuggingFace

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-mpnn%20molecular%20graphs-black.svg)](https://huggingface.co/keras-io/MPNN-for-molecular-property-prediction) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-mpnn%20molecular%20graphs-black.svg)](https://huggingface.co/spaces/keras-io/molecular-property-prediction) |
"""