"""
Title: Message-passing neural network (MPNN) for molecular property prediction
Author: [akensert](http://github.com/akensert)
Date created: 2021/08/16
Last modified: 2021/12/27
Description: Implementation of an MPNN to predict blood-brain barrier permeability.
Accelerator: GPU
"""

"""
## Introduction

In this tutorial, we will implement a type of graph neural network (GNN) known as
_message passing neural network_ (MPNN) to predict graph properties. Specifically, we will
implement an MPNN to predict a molecular property known as
_blood-brain barrier permeability_ (BBBP).

Motivation: as molecules are naturally represented as an undirected graph `G = (V, E)`,
where `V` is a set of vertices (nodes; atoms) and `E` a set of edges (bonds), GNNs (such
as MPNN) are proving to be a useful method for predicting molecular properties.

Traditionally, methods such as random forests and support vector machines
have been commonly used to predict molecular properties. In contrast to GNNs, these
traditional approaches often operate on precomputed molecular features such as
molecular weight, polarity, charge, number of carbon atoms, etc. Although these
molecular features prove to be good predictors for various molecular properties, it is
hypothesized that operating on more "raw", "low-level" features could prove even
better.

### References

In recent years, a lot of effort has been put into developing neural networks for
graph data, including molecular graphs.
For a summary of graph neural networks, see e.g.,
[A Comprehensive Survey on Graph Neural Networks](https://arxiv.org/abs/1901.00596) and
[Graph Neural Networks: A Review of Methods and Applications](https://arxiv.org/abs/1812.08434);
and for further reading on the specific
graph neural network implemented in this tutorial, see
[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) and
[DeepChem's MPNNModel](https://deepchem.readthedocs.io/en/latest/api_reference/models.html#mpnnmodel).
"""

"""
## Setup

### Install RDKit and other dependencies

(Text below taken from
[this tutorial](https://keras.io/examples/generative/wgan-graphs/)).

[RDKit](https://www.rdkit.org/) is a collection of cheminformatics and machine-learning
software written in C++ and Python. In this tutorial, RDKit is used to conveniently and
efficiently transform
[SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) to
molecule objects, and then from those obtain sets of atoms and bonds.

SMILES expresses the structure of a given molecule in the form of an ASCII string.
The SMILES string is a compact encoding which, for smaller molecules, is relatively
human-readable. Encoding molecules as strings makes database and/or web searches
for a given molecule easier.
RDKit uses algorithms to
accurately transform a given SMILES to a molecule object, which can then
be used to compute a great number of molecular properties/features.

Note that RDKit is commonly installed via [Conda](https://www.rdkit.org/docs/Install.html).
However, thanks to
[rdkit_platform_wheels](https://github.com/kuelumbus/rdkit_platform_wheels), rdkit
can now (for the sake of this tutorial) be installed easily via pip, as follows:

```
pip -q install rdkit-pypi
```

For easy and efficient reading of CSV files and for visualization, the following
packages also need to be installed:

```
pip -q install pandas
pip -q install Pillow
pip -q install matplotlib
pip -q install pydot
sudo apt-get -qq install graphviz
```
"""

"""
### Import packages
"""

import os

# Temporarily suppress TF logs
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem.Draw import IPythonConsole  # enables inline molecule rendering in notebooks
from rdkit.Chem.Draw import MolsToGridImage

# Temporarily suppress warnings and RDKit logs
warnings.filterwarnings("ignore")
RDLogger.DisableLog("rdApp.*")

np.random.seed(42)
tf.random.set_seed(42)

"""
## Dataset

Information about the dataset can be found in
[A Bayesian Approach to in Silico Blood-Brain Barrier Penetration Modeling](https://pubs.acs.org/doi/10.1021/ci300124c)
and [MoleculeNet: A Benchmark for Molecular Machine Learning](https://arxiv.org/abs/1703.00564).
The dataset will be downloaded from [MoleculeNet.org](https://moleculenet.org/datasets-1).

### About

The dataset contains **2,050** molecules.
Each molecule comes with a **name**, **label**
and **SMILES** string.

The blood-brain barrier (BBB) is a membrane separating the blood from the brain
extracellular fluid, hence blocking out most drugs (molecules) from reaching
the brain. Because of this, the BBBP has been important to study for the development of
new drugs that aim to target the central nervous system. The labels for this
data set are binary (1 or 0) and indicate the permeability of the molecules.
"""

csv_path = keras.utils.get_file(
    "BBBP.csv", "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv"
)

df = pd.read_csv(csv_path, usecols=[1, 2, 3])
df.iloc[96:104]

"""
### Define features

To encode features for atoms and bonds (which we will need later),
we'll define two classes: `AtomFeaturizer` and `BondFeaturizer` respectively.

To reduce the lines of code, i.e., to keep this tutorial short and concise,
only a handful of (atom and bond) features will be considered:

* atom features: [symbol (element)](https://en.wikipedia.org/wiki/Chemical_element),
[number of valence electrons](https://en.wikipedia.org/wiki/Valence_electron),
[number of hydrogens](https://en.wikipedia.org/wiki/Hydrogen) and
[orbital hybridization](https://en.wikipedia.org/wiki/Orbital_hybridisation);
* bond features: [(covalent) bond type](https://en.wikipedia.org/wiki/Covalent_bond) and
[conjugation](https://en.wikipedia.org/wiki/Conjugated_system).
"""


class Featurizer:
    def __init__(self, allowable_sets):
        # Assign every allowed value of every feature its own slot in a one-hot vector
        self.dim = 0
        self.features_mapping = {}
        for k, s in allowable_sets.items():
            s = sorted(list(s))
            self.features_mapping[k] = dict(zip(s, range(self.dim, len(s) + self.dim)))
            self.dim += len(s)

    def encode(self, inputs):
        output = np.zeros((self.dim,))
        for name_feature, feature_mapping in self.features_mapping.items():
            feature = getattr(self, name_feature)(inputs)
            # Values outside the allowable set leave their feature block all zeros
            if feature not in feature_mapping:
                continue
            output[feature_mapping[feature]] = 1.0
        return output


class AtomFeaturizer(Featurizer):
    def __init__(self, allowable_sets):
        super().__init__(allowable_sets)

    def symbol(self, atom):
        return atom.GetSymbol()

    def n_valence(self, atom):
        return atom.GetTotalValence()

    def n_hydrogens(self, atom):
        return atom.GetTotalNumHs()

    def hybridization(self, atom):
        return atom.GetHybridization().name.lower()


class BondFeaturizer(Featurizer):
    def __init__(self, allowable_sets):
        super().__init__(allowable_sets)
        # One extra slot flags self-loops, which have no actual bond (bond is None)
        self.dim += 1

    def encode(self, bond):
        output = np.zeros((self.dim,))
        if bond is None:
            output[-1] = 1.0
            return output
        output = super().encode(bond)
        return output

    def bond_type(self, bond):
        return bond.GetBondType().name.lower()

    def conjugated(self, bond):
        return bond.GetIsConjugated()


atom_featurizer = AtomFeaturizer(
    allowable_sets={
        "symbol": {"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"},
        "n_valence": {0, 1, 2, 3, 4, 5, 6},
        "n_hydrogens": {0, 1, 2, 3, 4},
        "hybridization": {"s", "sp", "sp2", "sp3"},
    }
)

bond_featurizer = BondFeaturizer(
    allowable_sets={
        "bond_type": {"single", "double", "triple", "aromatic"},
        "conjugated": {True, False},
    }
)


"""
### Generate graphs

Before we can generate complete graphs from SMILES, we need to implement the following functions:

1. `molecule_from_smiles`, which takes as input a SMILES and returns a molecule object.
This is all handled by RDKit.

2. `graph_from_molecule`, which takes as input a molecule object and returns a graph,
represented as a three-tuple (atom_features, bond_features, pair_indices).
For this, we
will make use of the classes defined previously.

Finally, we can now implement the function `graphs_from_smiles`, which applies function (1)
and subsequently (2) to all SMILES of the training, validation and test datasets.

Note: although scaffold splitting is recommended for this data set (see
[here](https://arxiv.org/abs/1703.00564)), a simple random split is used here
for simplicity.
"""


def molecule_from_smiles(smiles):
    # MolFromSmiles(m, sanitize=True) should be equivalent to
    # MolFromSmiles(m, sanitize=False) -> SanitizeMol(m) -> AssignStereochemistry(m, ...)
    molecule = Chem.MolFromSmiles(smiles, sanitize=False)

    # If sanitization is unsuccessful, catch the error, and try again without
    # the sanitization step that caused the error
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)

    Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
    return molecule


def graph_from_molecule(molecule):
    # Initialize graph
    atom_features = []
    bond_features = []
    pair_indices = []

    for atom in molecule.GetAtoms():
        atom_features.append(atom_featurizer.encode(atom))

        # Add self-loops
        pair_indices.append([atom.GetIdx(), atom.GetIdx()])
        bond_features.append(bond_featurizer.encode(None))

        # Add one [atom, neighbor] pair per bond (each bond appears in both directions)
        for neighbor in atom.GetNeighbors():
            bond = molecule.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx())
            pair_indices.append([atom.GetIdx(), neighbor.GetIdx()])
            bond_features.append(bond_featurizer.encode(bond))

    return np.array(atom_features), np.array(bond_features), np.array(pair_indices)


def graphs_from_smiles(smiles_list):
    # Initialize graphs
    atom_features_list = []
    bond_features_list = []
    pair_indices_list = []

    for smiles in smiles_list:
        molecule = molecule_from_smiles(smiles)
        atom_features, bond_features, pair_indices = graph_from_molecule(molecule)

        atom_features_list.append(atom_features)
        bond_features_list.append(bond_features)
        pair_indices_list.append(pair_indices)

    # Convert lists to ragged tensors for tf.data.Dataset later on
    return (
        tf.ragged.constant(atom_features_list, dtype=tf.float32),
        tf.ragged.constant(bond_features_list, dtype=tf.float32),
        tf.ragged.constant(pair_indices_list, dtype=tf.int64),
    )


# Shuffle array of indices ranging from 0 to 2049
permuted_indices = np.random.permutation(np.arange(df.shape[0]))

# Train set: 80 % of data
train_index = permuted_indices[: int(df.shape[0] * 0.8)]
x_train = graphs_from_smiles(df.iloc[train_index].smiles)
y_train = df.iloc[train_index].p_np

# Valid set: 19 % of data
valid_index = permuted_indices[int(df.shape[0] * 0.8) : int(df.shape[0] * 0.99)]
x_valid = graphs_from_smiles(df.iloc[valid_index].smiles)
y_valid = df.iloc[valid_index].p_np

# Test set: 1 % of data
test_index = permuted_indices[int(df.shape[0] * 0.99) :]
x_test = graphs_from_smiles(df.iloc[test_index].smiles)
y_test = df.iloc[test_index].p_np

"""
### Test the functions
"""

print(f"Name:\t{df.name[100]}\nSMILES:\t{df.smiles[100]}\nBBBP:\t{df.p_np[100]}")
molecule = molecule_from_smiles(df.iloc[100].smiles)
print("Molecule:")
molecule

"""
"""

graph = graph_from_molecule(molecule)
print("Graph (including self-loops):")
print("\tatom features\t", graph[0].shape)
print("\tbond features\t", graph[1].shape)
print("\tpair indices\t", graph[2].shape)


"""
### Create a `tf.data.Dataset`

In this tutorial, the MPNN implementation will take as input (per iteration) a single graph.
Therefore, given a batch of (sub)graphs (molecules), we need to merge them into a
single graph (we'll refer to this graph as the *global graph*).
This global graph is a disconnected graph where each subgraph is
completely separated from the other subgraphs.
"""


def prepare_batch(x_batch, y_batch):
    """Merges (sub)graphs of batch into a single global (disconnected) graph"""

    atom_features, bond_features, pair_indices = x_batch

    # Obtain number of atoms and bonds for each graph (molecule)
    num_atoms = atom_features.row_lengths()
    num_bonds = bond_features.row_lengths()

    # Obtain partition indices (molecule_indicator), which will be used to
    # gather (sub)graphs from global graph in model later on
    molecule_indices = tf.range(len(num_atoms))
    molecule_indicator = tf.repeat(molecule_indices, num_atoms)

    # Merge (sub)graphs into a global (disconnected) graph. Adding 'increment' to
    # 'pair_indices' shifts each molecule's local atom indices by the number of atoms
    # in the preceding molecules, so that (after merging the ragged tensors) all
    # pairs index into the global graph
    gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:])
    increment = tf.cumsum(num_atoms[:-1])
    increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)])
    pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    pair_indices = pair_indices + increment[:, tf.newaxis]
    atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()

    return (atom_features, bond_features, pair_indices, molecule_indicator), y_batch


def MPNNDataset(X, y, batch_size=32, shuffle=False):
    dataset = tf.data.Dataset.from_tensor_slices((X, (y)))
    if shuffle:
        dataset = dataset.shuffle(1024)
    dataset = dataset.batch(batch_size).map(prepare_batch, tf.data.AUTOTUNE)
    return dataset.prefetch(tf.data.AUTOTUNE)


"""
## Model

The MPNN model can take on various shapes and forms.
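
In the general formulation of the original paper, a single message passing step computes, for
every node `v`, a message `m_v` by summing a message function `M(h_v, h_w, e_vw)` over the
neighbors `w` of `v`, and then updates the node state as `h_v' = U(h_v, m_v)`. A readout
function finally maps the node states to a graph-level prediction, and different choices of
these functions yield different MPNN variants. The NumPy sketch below illustrates one step
with deliberately simple choices of `M` and `U`. The names `toy_message` and `toy_update` are
hypothetical and illustrative only; they are not the edge network and GRU used in this
tutorial. `pair_indices` follows the same `[v, w]` row convention as `graph_from_molecule`.

```python
import numpy as np


def message_passing_step(h, e, pair_indices, message_fn, update_fn):
    # h: (num_nodes, d) node states; e: (num_pairs, d_e) edge features
    # pair_indices: (num_pairs, 2) rows of [target v, neighbor w]
    v, w = pair_indices[:, 0], pair_indices[:, 1]
    messages = message_fn(h[v], h[w], e)  # one message per (v, w) pair
    m = np.zeros_like(h)
    np.add.at(m, v, messages)  # m_v = sum of messages over the neighbors of v
    return update_fn(h, m)  # h_v' = U(h_v, m_v)


# Hypothetical toy choices: the message is the neighbor state scaled by a scalar
# edge weight, and the update adds the aggregated message to the current state
def toy_message(h_v, h_w, e):
    return h_w * e


def toy_update(h, m):
    return h + m


# Path graph 0-1-2 with both edge directions listed and unit edge weights
h = np.array([[1.0], [2.0], [3.0]])
pair_indices = np.array([[0, 1], [1, 0], [1, 2], [2, 1]])
e = np.ones((4, 1))
h_next = message_passing_step(h, e, pair_indices, toy_message, toy_update)
print(h_next.ravel())  # [3. 6. 5.]
```

Because `np.add.at` performs an unbuffered scatter-add, repeated target indices accumulate
correctly. `tf.math.unsorted_segment_sum` plays the same role in `EdgeNetwork`.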

In this tutorial, we will implement an
MPNN based on the original paper
[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) and
[DeepChem's MPNNModel](https://deepchem.readthedocs.io/en/latest/api_reference/models.html#mpnnmodel).
The MPNN of this tutorial consists of three stages: message passing, readout and
classification.


### Message passing

The message passing step itself consists of two parts:

1. The *edge network*, which passes messages from 1-hop neighbors `w_{i}` of `v`
to `v`, based on the edge features between them (`e_{vw_{i}}`),
resulting in an updated node (state) `v'`. `w_{i}` denotes the `i`-th neighbor of
`v`.

2. The *gated recurrent unit* (GRU), which takes as input the most recent node state
and updates it based on previous node states. In
other words, the most recent node state serves as the input to the GRU, while the previous
node states are incorporated within the memory state of the GRU. This allows information
to travel from one node state (e.g., `v`) to another (e.g., `v''`).

Importantly, steps (1) and (2) are repeated for `k` steps, and at each step `1...k`,
the radius (or number of hops) of aggregated information from `v` increases by 1.
"""


class EdgeNetwork(layers.Layer):
    def build(self, input_shape):
        self.atom_dim = input_shape[0][-1]
        self.bond_dim = input_shape[1][-1]
        # Maps each bond feature vector to a flattened (atom_dim, atom_dim) matrix
        self.kernel = self.add_weight(
            shape=(self.bond_dim, self.atom_dim * self.atom_dim),
            initializer="glorot_uniform",
            name="kernel",
        )
        self.bias = self.add_weight(
            shape=(self.atom_dim * self.atom_dim,),
            initializer="zeros",
            name="bias",
        )
        self.built = True

    def call(self, inputs):
        atom_features, bond_features, pair_indices = inputs

        # Apply linear transformation to bond features
        bond_features = tf.matmul(bond_features, self.kernel) + self.bias

        # Reshape for neighborhood aggregation later
        bond_features = tf.reshape(bond_features, (-1, self.atom_dim, self.atom_dim))

        # Obtain atom features of neighbors
        atom_features_neighbors = tf.gather(atom_features, pair_indices[:, 1])
        atom_features_neighbors = tf.expand_dims(atom_features_neighbors, axis=-1)

        # Apply neighborhood aggregation: multiply each neighbor's features by its
        # bond-specific matrix, then sum the results per target atom
        transformed_features = tf.matmul(bond_features, atom_features_neighbors)
        transformed_features = tf.squeeze(transformed_features, axis=-1)
        aggregated_features = tf.math.unsorted_segment_sum(
            transformed_features,
            pair_indices[:, 0],
            num_segments=tf.shape(atom_features)[0],
        )
        return aggregated_features


class MessagePassing(layers.Layer):
    def __init__(self, units, steps=4, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.steps = steps

    def build(self, input_shape):
        self.atom_dim = input_shape[0][-1]
        self.message_step = EdgeNetwork()
        self.pad_length = max(0, self.units - self.atom_dim)
        self.update_step = layers.GRUCell(self.atom_dim + self.pad_length)
        self.built = True

    def call(self, inputs):
        atom_features, bond_features, pair_indices = inputs

        # Pad atom features if number of desired units exceeds atom_features dim.
        # Alternatively, a dense layer could be used here.
        atom_features_updated = tf.pad(atom_features, [(0, 0), (0, self.pad_length)])

        # Perform a number of steps of message passing
        for i in range(self.steps):
            # Aggregate information from neighbors
            atom_features_aggregated = self.message_step(
                [atom_features_updated, bond_features, pair_indices]
            )

            # Update node state via a step of GRU
            atom_features_updated, _ = self.update_step(
                atom_features_aggregated, atom_features_updated
            )
        return atom_features_updated


"""
### Readout

When the message passing procedure ends, the k-step-aggregated node states are to be partitioned
into subgraphs (corresponding to each molecule in the batch) and subsequently
reduced to graph-level embeddings.
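
The partition-and-reduce idea can be illustrated with a much simpler readout than the one used
in this tutorial. The NumPy sketch below assumes that the node states of a batch are stacked
into one array, together with a `molecule_indicator` that maps each node to its molecule (as
produced by `prepare_batch`). It mean-pools the nodes of each molecule into one embedding. The
helper `mean_readout` is hypothetical. In TensorFlow, `tf.math.unsorted_segment_mean` computes
a comparable per-segment average.

```python
import numpy as np


def mean_readout(node_states, molecule_indicator, num_molecules):
    # Sum the node states of each molecule (unbuffered scatter-add by row index)
    sums = np.zeros((num_molecules, node_states.shape[1]))
    np.add.at(sums, molecule_indicator, node_states)
    # Divide by the number of nodes per molecule (guarding against empty molecules)
    counts = np.bincount(molecule_indicator, minlength=num_molecules)
    return sums / np.maximum(counts, 1)[:, np.newaxis]


# Toy global graph: molecule 0 owns nodes 0-1, molecule 1 owns nodes 2-4
node_states = np.array([[1.0, 0.0], [3.0, 2.0], [0.0, 3.0], [3.0, 3.0], [0.0, 0.0]])
molecule_indicator = np.array([0, 0, 1, 1, 1])
embeddings = mean_readout(node_states, molecule_indicator, num_molecules=2)
print(embeddings)  # row 0 is [2., 1.], row 1 is [1., 2.]
```

Here every node contributes equally to its molecule's embedding. The transformer readout below
instead lets the nodes of a molecule attend to each other before pooling.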

In the [original paper](https://arxiv.org/abs/1704.01212), a
[set-to-set layer](https://arxiv.org/abs/1511.06391) was used for the readout.
In this tutorial, however, a transformer encoder + average pooling will be used. Specifically:

* the k-step-aggregated node states will be partitioned into the subgraphs
(corresponding to each molecule in the batch);
* each subgraph will then be padded to match the subgraph with the greatest number of nodes, followed
by a `tf.stack(...)`;
* the (stacked, padded) tensor encoding the subgraphs (each subgraph containing a set of node states) is
masked to make sure the paddings don't interfere with training;
* finally, the tensor is passed to the transformer followed by average pooling.
"""


class PartitionPadding(layers.Layer):
    def __init__(self, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def call(self, inputs):
        atom_features, molecule_indicator = inputs

        # Obtain subgraphs
        atom_features_partitioned = tf.dynamic_partition(
            atom_features, molecule_indicator, self.batch_size
        )

        # Pad and stack subgraphs
        num_atoms = [tf.shape(f)[0] for f in atom_features_partitioned]
        max_num_atoms = tf.reduce_max(num_atoms)
        atom_features_stacked = tf.stack(
            [
                tf.pad(f, [(0, max_num_atoms - n), (0, 0)])
                for f, n in zip(atom_features_partitioned, num_atoms)
            ],
            axis=0,
        )

        # Remove empty subgraphs (e.g., when the last batch of the dataset contains
        # fewer than `batch_size` molecules)
        gather_indices = tf.where(tf.reduce_sum(atom_features_stacked, (1, 2)) != 0)
        gather_indices = tf.squeeze(gather_indices, axis=-1)
        return tf.gather(atom_features_stacked, gather_indices, axis=0)


class TransformerEncoderReadout(layers.Layer):
    def __init__(
        self, num_heads=8, embed_dim=64, dense_dim=512, batch_size=32, **kwargs
    ):
        super().__init__(**kwargs)

        self.partition_padding = PartitionPadding(batch_size)
        self.attention = layers.MultiHeadAttention(num_heads, embed_dim)
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.average_pooling = layers.GlobalAveragePooling1D()

    def call(self, inputs):
        x = self.partition_padding(inputs)
        # Mask out padded (all-zero) node positions as attention keys
        padding_mask = tf.reduce_any(tf.not_equal(x, 0.0), axis=-1)
        padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]
        attention_output = self.attention(x, x, attention_mask=padding_mask)
        proj_input = self.layernorm_1(x + attention_output)
        proj_output = self.layernorm_2(proj_input + self.dense_proj(proj_input))
        return self.average_pooling(proj_output)


"""
### Message Passing Neural Network (MPNN)

It is now time to complete the MPNN model. In addition to the message passing
and readout, a two-layer classification network will be implemented to make
predictions of BBBP.
"""


def MPNNModel(
    atom_dim,
    bond_dim,
    batch_size=32,
    message_units=64,
    message_steps=4,
    num_attention_heads=8,
    dense_units=512,
):
    atom_features = layers.Input((atom_dim,), dtype="float32", name="atom_features")
    bond_features = layers.Input((bond_dim,), dtype="float32", name="bond_features")
    pair_indices = layers.Input((2,), dtype="int32", name="pair_indices")
    molecule_indicator = layers.Input((), dtype="int32", name="molecule_indicator")

    x = MessagePassing(message_units, message_steps)(
        [atom_features, bond_features, pair_indices]
    )

    x = TransformerEncoderReadout(
        num_attention_heads, message_units, dense_units, batch_size
    )([x, molecule_indicator])

    x = layers.Dense(dense_units, activation="relu")(x)
    x = layers.Dense(1, activation="sigmoid")(x)

    model = keras.Model(
        inputs=[atom_features, bond_features, pair_indices, molecule_indicator],
        outputs=[x],
    )
    return model


mpnn = MPNNModel(
    atom_dim=x_train[0][0][0].shape[0],
    bond_dim=x_train[1][0][0].shape[0],
)

mpnn.compile(
    loss=keras.losses.BinaryCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=5e-4),
    metrics=[keras.metrics.AUC(name="AUC")],
)

keras.utils.plot_model(mpnn, show_dtype=True, show_shapes=True)

"""
### Training
"""

train_dataset = MPNNDataset(x_train, y_train)
valid_dataset = MPNNDataset(x_valid, y_valid)
test_dataset = MPNNDataset(x_test, y_test)

history = mpnn.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=40,
    verbose=2,
    class_weight={0: 2.0, 1: 0.5},
)

plt.figure(figsize=(10, 6))
plt.plot(history.history["AUC"], label="train AUC")
plt.plot(history.history["val_AUC"], label="valid AUC")
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("AUC", fontsize=16)
plt.legend(fontsize=16)

"""
### Predicting
"""

molecules = [molecule_from_smiles(df.smiles.values[index]) for index in test_index]
y_true = [df.p_np.values[index] for index in test_index]
y_pred = tf.squeeze(mpnn.predict(test_dataset), axis=1)

legends = [f"y_true/y_pred = {y_true[i]}/{y_pred[i]:.2f}" for i in range(len(y_true))]
MolsToGridImage(molecules, molsPerRow=4, legends=legends)

"""
## Conclusions

In this tutorial, we demonstrated how to use a message passing neural network (MPNN) to
predict blood-brain barrier permeability (BBBP) for a number of different molecules. We
first had to construct graphs from SMILES, then build a Keras model that could
operate on these graphs, and finally train the model to make the predictions.

Example available on HuggingFace:

| Trained Model | Demo |
| :--: | :--: |
| [keras-io/MPNN-for-molecular-property-prediction](https://huggingface.co/keras-io/MPNN-for-molecular-property-prediction) | [keras-io/molecular-property-prediction](https://huggingface.co/spaces/keras-io/molecular-property-prediction) |
"""
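
"""
## Appendix: merging a batch into a global graph

The index bookkeeping in `prepare_batch` is arguably the least obvious part of the data
pipeline. The NumPy sketch below reproduces its effect on `pair_indices` for a toy batch. It
assumes that each molecule's pairs are stored as an integer array of local atom indices, as
returned by `graph_from_molecule`. The helper `merge_pair_indices` is hypothetical and not
part of the tutorial code. It shifts each molecule's pairs by the number of atoms in the
preceding molecules, which is the role `increment` plays in the TensorFlow version, and it
builds `molecule_indicator` with a repeat in the same way.

```python
import numpy as np


def merge_pair_indices(pair_indices_per_molecule, num_atoms_per_molecule):
    # Offset of each molecule = number of atoms in all molecules before it
    offsets = np.concatenate([[0], np.cumsum(num_atoms_per_molecule)[:-1]])
    # Shift local atom indices so that they point into one global atom array
    shifted = [pairs + offset for pairs, offset in zip(pair_indices_per_molecule, offsets)]
    merged = np.concatenate(shifted, axis=0)
    # Map every global atom back to the molecule it belongs to
    molecule_indicator = np.repeat(
        np.arange(len(num_atoms_per_molecule)), num_atoms_per_molecule
    )
    return merged, molecule_indicator


# Toy batch, with each atom's self-loop listed before its bonds (as in graph_from_molecule)
# Molecule A: 2 atoms bonded 0-1; molecule B: 3 atoms bonded 0-1-2
pairs_a = np.array([[0, 0], [0, 1], [1, 1], [1, 0]])
pairs_b = np.array([[0, 0], [0, 1], [1, 1], [1, 0], [1, 2], [2, 2], [2, 1]])
merged, molecule_indicator = merge_pair_indices([pairs_a, pairs_b], [2, 3])
print(merged.shape)  # (11, 2)
print(molecule_indicator)  # [0 0 1 1 1]
```

The first self-loop `[0, 0]` of molecule B becomes `[2, 2]` because two atoms precede it in
the batch.
"""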