Message-passing neural network (MPNN) for molecular property prediction
Author: akensert
Date created: 2021/08/16
Last modified: 2021/12/27
Description: Implementation of an MPNN to predict blood-brain barrier permeability.
Introduction
In this tutorial, we will implement a type of graph neural network (GNN) known as _message passing neural network_ (MPNN) to predict graph properties. Specifically, we will implement an MPNN to predict a molecular property known as blood-brain barrier permeability (BBBP).
Motivation: as molecules are naturally represented as an undirected graph G = (V, E), where V is a set of vertices (nodes; atoms) and E a set of edges (bonds), GNNs (such as MPNN) are proving to be a useful method for predicting molecular properties.
Until now, more traditional methods, such as random forests and support vector machines, have been commonly used to predict molecular properties. In contrast to GNNs, these traditional approaches often operate on precomputed molecular features such as molecular weight, polarity, charge, number of carbon atoms, etc. Although these molecular features prove to be good predictors for various molecular properties, it is hypothesized that operating on the more "raw", "low-level" features of the molecular graph could prove even better.
References
In recent years, a lot of effort has been put into developing neural networks for graph data, including molecular graphs. For a summary of graph neural networks, see e.g., A Comprehensive Survey on Graph Neural Networks and Graph Neural Networks: A Review of Methods and Applications; and for further reading on the specific graph neural network implemented in this tutorial see Neural Message Passing for Quantum Chemistry and DeepChem's MPNNModel.
Setup
Install RDKit and other dependencies
(Text below taken from this tutorial).
RDKit is a collection of cheminformatics and machine-learning software written in C++ and Python. In this tutorial, RDKit is used to conveniently and efficiently transform SMILES to molecule objects, and then from those obtain sets of atoms and bonds.
SMILES expresses the structure of a given molecule in the form of an ASCII string. The SMILES string is a compact encoding which, for smaller molecules, is relatively human-readable. Encoding molecules as a string both alleviates and facilitates database and/or web searching of a given molecule. RDKit uses algorithms to accurately transform a given SMILES to a molecule object, which can then be used to compute a great number of molecular properties/features.
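As a quick illustration of what this looks like in practice, the small example below parses a SMILES string with RDKit and inspects the resulting molecule object (caffeine is used purely as an example molecule):

```python
from rdkit import Chem

# Parse a SMILES string (caffeine, used purely as an example) into a molecule object
molecule = Chem.MolFromSmiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C")

print(molecule.GetNumAtoms())  # number of heavy atoms
print(molecule.GetNumBonds())  # number of bonds
```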
Notice, RDKit is commonly installed via Conda. However, thanks to rdkit_platform_wheels, rdkit can now (for the sake of this tutorial) be installed easily via pip, as follows:
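For example (the exact package name is environment-dependent; at the time of writing the pip wheel was published as rdkit-pypi, while newer environments may simply use rdkit):

```
pip -q install rdkit-pypi
```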
And for easy and efficient reading of csv files and visualization, the below needs to be installed:
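Something along these lines should suffice (the exact package list is an assumption based on what this tutorial uses for data handling and visualization):

```
pip -q install pandas Pillow matplotlib pydot
```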
Import packages
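A typical import cell for this tutorial might look roughly as follows (the exact set of imports and seeds is an approximation; later code in this tutorial assumes these names are available):

```python
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from rdkit import Chem, RDLogger
from rdkit.Chem.Draw import MolsToGridImage

# Fix random seeds for (approximate) reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Silence RDKit's warning messages to keep the notebook output clean
RDLogger.DisableLog("rdApp.*")
```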
Dataset
Information about the dataset can be found in A Bayesian Approach to in Silico Blood-Brain Barrier Penetration Modeling and MoleculeNet: A Benchmark for Molecular Machine Learning. The dataset will be downloaded from MoleculeNet.org.
About
The dataset contains 2,050 molecules. Each molecule comes with a name, a label and a SMILES string.
The blood-brain barrier (BBB) is a membrane separating the blood from the brain extracellular fluid, hence blocking out most drugs (molecules) from reaching the brain. Because of this, the BBBP has been important to study for the development of new drugs that aim to target the central nervous system. The labels for this data set are binary (1 or 0) and indicate the permeability of the molecules.
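A minimal loading sketch is shown below; the download URL (DeepChem's file hosting for MoleculeNet) and the column selection are assumptions, so adjust them to wherever you obtain BBBP.csv from:

```python
# Download the BBBP csv file and keep the name, label (p_np) and SMILES columns
csv_path = keras.utils.get_file(
    "BBBP.csv", "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv"
)
df = pd.read_csv(csv_path, usecols=[1, 2, 3])
print(df.shape)
df.head()
```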
Define features
To encode features for atoms and bonds (which we will need later), we'll define two classes: AtomFeaturizer and BondFeaturizer, respectively.
To reduce the lines of code, i.e., to keep this tutorial short and concise, only about a handful of (atom and bond) features will be considered: [atom features] symbol (element), number of valence electrons, number of hydrogen bonds, orbital hybridization, [bond features] (covalent) bond type, and conjugation.
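Below is a condensed sketch of how such a featurizer could be written as a one-hot encoder over a few allowable feature values. The class layout and the chosen feature sets are illustrative assumptions; a BondFeaturizer (bond type, conjugation) would follow the same pattern.

```python
import numpy as np


class Featurizer:
    """One-hot encodes a handful of named features (sketch)."""

    def __init__(self, allowable_sets):
        self.dim = 0
        self.features_mapping = {}
        for name, allowable_set in allowable_sets.items():
            sorted_values = sorted(allowable_set)
            self.features_mapping[name] = dict(
                zip(sorted_values, range(self.dim, self.dim + len(sorted_values)))
            )
            self.dim += len(sorted_values)

    def encode(self, inputs):
        output = np.zeros((self.dim,))
        for name, mapping in self.features_mapping.items():
            value = getattr(self, name)(inputs)
            if value in mapping:
                output[mapping[value]] = 1.0
        return output


class AtomFeaturizer(Featurizer):
    def symbol(self, atom):
        return atom.GetSymbol()

    def n_valence(self, atom):
        return atom.GetTotalValence()

    def n_hydrogens(self, atom):
        return atom.GetTotalNumHs()

    def hybridization(self, atom):
        return atom.GetHybridization().name.lower()


# Illustrative feature sets; values outside these sets simply encode to all zeros
atom_featurizer = AtomFeaturizer(
    allowable_sets={
        "symbol": {"B", "Br", "C", "Cl", "F", "I", "N", "O", "P", "S"},
        "n_valence": {0, 1, 2, 3, 4, 5, 6},
        "n_hydrogens": {0, 1, 2, 3, 4},
        "hybridization": {"s", "sp", "sp2", "sp3"},
    }
)
```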
Generate graphs
Before we can generate complete graphs from SMILES, we need to implement the following functions:
1. molecule_from_smiles, which takes as input a SMILES and returns a molecule object. This is all handled by RDKit.
2. graph_from_molecule, which takes as input a molecule object and returns a graph, represented as a three-tuple (atom_features, bond_features, pair_indices). For this we will make use of the classes defined previously.
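A possible implementation of the first function, built on RDKit's sanitization API, is sketched below (treat it as an illustrative version rather than the definitive one):

```python
def molecule_from_smiles(smiles):
    # MolFromSmiles(smiles, sanitize=True) should be equivalent to
    # MolFromSmiles(smiles, sanitize=False) -> SanitizeMol(mol) -> AssignStereochemistry(mol, ...)
    molecule = Chem.MolFromSmiles(smiles, sanitize=False)

    # If sanitization fails, catch the error and retry while skipping
    # the sanitization operation that caused it
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)

    Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
    return molecule
```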
Finally, we can now implement the function graphs_from_smiles, which applies function (1) and subsequently (2) on all SMILES of the training, validation and test datasets.
Notice: although scaffold splitting is recommended for this data set (see here), for simplicity, simple random splittings were performed.
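A simple random split could look like the sketch below; the fractions, the column names (smiles, p_np) and the graphs_from_smiles call are assumptions that mirror the prose above:

```python
# Shuffle the rows and split into train (80 %), validation (19 %) and test (1 %)
permuted_indices = np.random.permutation(np.arange(df.shape[0]))

train_index = permuted_indices[: int(df.shape[0] * 0.80)]
valid_index = permuted_indices[int(df.shape[0] * 0.80) : int(df.shape[0] * 0.99)]
test_index = permuted_indices[int(df.shape[0] * 0.99) :]

# Convert SMILES to graphs and grab the binary permeability labels
x_train = graphs_from_smiles(df.iloc[train_index].smiles)
y_train = df.iloc[train_index].p_np

x_valid = graphs_from_smiles(df.iloc[valid_index].smiles)
y_valid = df.iloc[valid_index].p_np

x_test = graphs_from_smiles(df.iloc[test_index].smiles)
y_test = df.iloc[test_index].p_np
```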
Test the functions
Create a tf.data.Dataset
In this tutorial, the MPNN implementation will take as input (per iteration) a single graph. Therefore, given a batch of (sub)graphs (molecules), we need to merge them into a single graph (we'll refer to this graph as the global graph). This global graph is a disconnected graph in which each subgraph is completely separated from the other subgraphs.
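To make the merging concrete, here is a tiny, self-contained illustration (toy indices only) of the key trick: offsetting each subgraph's pair indices by the number of atoms that precede it, so that all subgraphs index into one disconnected global graph:

```python
import numpy as np

# Two toy subgraphs; pair_indices reference atoms *within* each subgraph
pair_indices_a = np.array([[0, 1], [1, 0]])           # subgraph A has 2 atoms
pair_indices_b = np.array([[0, 1], [1, 2], [2, 0]])   # subgraph B has 3 atoms

# Offset subgraph B's indices by the number of atoms in subgraph A (2),
# so both index into the merged (global) node list of 5 atoms
merged_pair_indices = np.concatenate([pair_indices_a, pair_indices_b + 2], axis=0)
print(merged_pair_indices)
# rows: [0 1], [1 0], [2 3], [3 4], [4 2]
```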
Model
The MPNN model can take on various shapes and forms. In this tutorial, we will implement an MPNN based on the original paper Neural Message Passing for Quantum Chemistry and DeepChem's MPNNModel. The MPNN of this tutorial consists of three stages: message passing, readout and classification.
Message passing
The message passing step itself consists of two parts:
1. The edge network, which passes messages from 1-hop neighbors w_{i} of v to v, based on the edge features between them (e_{vw_{i}}), resulting in an updated node (state) v'. Here, w_{i} denotes the i:th neighbor of v.
2. The gated recurrent unit (GRU), which takes as input the most recent node state and updates it based on previous node states. In other words, the most recent node state serves as the input to the GRU, while the previous node states are incorporated within the memory state of the GRU. This allows information to travel from one node state (e.g., v) to another (e.g., v'').
Importantly, steps (1) and (2) are repeated for k steps, where at each step 1...k, the radius (or number of hops) of aggregated information from v increases by 1.
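A condensed sketch of how this could be expressed as Keras layers is shown below, building on the imports above. It follows the edge-network + GRU scheme just described, but layer names, hyperparameters and some details are assumptions, so treat it as illustrative rather than as the tutorial's exact implementation.

```python
class EdgeNetwork(layers.Layer):
    """Maps each bond feature vector to an (atom_dim x atom_dim) matrix used to
    transform and aggregate neighbor states (sketch)."""

    def build(self, input_shape):
        self.atom_dim = input_shape[0][-1]
        self.bond_dim = input_shape[1][-1]
        self.kernel = self.add_weight(
            shape=(self.bond_dim, self.atom_dim * self.atom_dim),
            initializer="glorot_uniform",
            name="kernel",
        )
        self.bias = self.add_weight(
            shape=(self.atom_dim * self.atom_dim,), initializer="zeros", name="bias"
        )
        self.built = True

    def call(self, inputs):
        atom_features, bond_features, pair_indices = inputs
        # Turn each bond feature vector into an (atom_dim, atom_dim) matrix
        bond_features = tf.matmul(bond_features, self.kernel) + self.bias
        bond_features = tf.reshape(bond_features, (-1, self.atom_dim, self.atom_dim))
        # Gather the neighbor (message source) atom states
        neighbor_states = tf.gather(atom_features, pair_indices[:, 1])
        neighbor_states = tf.expand_dims(neighbor_states, axis=-1)
        # Transform the messages and sum them per receiving atom
        messages = tf.squeeze(tf.matmul(bond_features, neighbor_states), axis=-1)
        return tf.math.unsorted_segment_sum(
            messages, pair_indices[:, 0], num_segments=tf.shape(atom_features)[0]
        )


class MessagePassing(layers.Layer):
    """Runs `steps` rounds of message passing, updating node states with a GRU cell."""

    def __init__(self, units, steps=4, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.steps = steps

    def build(self, input_shape):
        self.atom_dim = input_shape[0][-1]
        self.message_step = EdgeNetwork()
        self.pad_length = max(0, self.units - self.atom_dim)
        self.update_step = layers.GRUCell(self.atom_dim + self.pad_length)
        self.built = True

    def call(self, inputs):
        atom_features, bond_features, pair_indices = inputs
        # Pad node states up to `units` dimensions if needed
        node_states = tf.pad(atom_features, [(0, 0), (0, self.pad_length)])
        for _ in range(self.steps):
            # (1) aggregate messages from 1-hop neighbors via the edge network
            aggregated = self.message_step([node_states, bond_features, pair_indices])
            # (2) update node states with a GRU step (previous state acts as memory)
            node_states, _ = self.update_step(aggregated, node_states)
        return node_states
```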
Readout
When the message passing procedure ends, the k-step-aggregated node states are to be partitioned into subgraphs (corresponding to each molecule in the batch) and subsequently reduced to graph-level embeddings. In the original paper, a set-to-set layer was used for this purpose. In this tutorial however, a transformer encoder + average pooling will be used. Specifically:
- the k-step-aggregated node states will be partitioned into the subgraphs (corresponding to each molecule in the batch);
- each subgraph will then be padded to match the subgraph with the greatest number of nodes, followed by a tf.stack(...);
- the (stacked and padded) tensor, encoding the subgraphs (each subgraph containing a set of node states), is masked to make sure the paddings don't interfere with training;
- finally, the tensor is passed to the transformer, followed by average pooling.
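A minimal sketch of the transformer-encoder + average-pooling readout is shown below. It assumes the node states have already been partitioned per molecule and zero-padded into a dense tensor x of shape (batch_size, max_num_atoms, embed_dim); the layer sizes are illustrative choices.

```python
def transformer_readout(x, num_heads=8, embed_dim=64, dense_dim=512):
    # Mask out padded (all-zero) node states so they don't take part in attention
    padding_mask = tf.reduce_any(tf.not_equal(x, 0.0), axis=-1)
    padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]

    # One transformer-encoder block: self-attention + feed-forward, each with a residual
    attention_output = layers.MultiHeadAttention(num_heads, embed_dim)(
        x, x, attention_mask=padding_mask
    )
    x = layers.LayerNormalization()(x + attention_output)
    ffn = keras.Sequential(
        [layers.Dense(dense_dim, activation="relu"), layers.Dense(embed_dim)]
    )
    x = layers.LayerNormalization()(x + ffn(x))

    # Average over the node axis to obtain one embedding per molecule
    return layers.GlobalAveragePooling1D()(x)
```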
Message Passing Neural Network (MPNN)
It is now time to complete the MPNN model. In addition to the message passing and readout, a two-layer classification network will be implemented to make predictions of BBBP.
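The classification part on its own is small; a sketch of the two-layer head on top of a graph-level embedding is shown below (the 64-dimensional embedding size and 512 hidden units are illustrative choices):

```python
# Graph-level embedding produced by the message passing + readout stages
graph_embedding = layers.Input(shape=(64,), name="graph_embedding")

x = layers.Dense(512, activation="relu")(graph_embedding)
outputs = layers.Dense(1, activation="sigmoid")(x)  # predicted BBBP probability

classification_head = keras.Model(graph_embedding, outputs)
classification_head.summary()
```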
Training
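A typical compile-and-fit cell could look like this; the MPNNModel constructor arguments, the optimizer settings and the number of epochs are assumptions, and train_dataset/valid_dataset come from the tf.data step above:

```python
# Assuming a bond_featurizer defined analogously to atom_featurizer above
mpnn = MPNNModel(atom_dim=atom_featurizer.dim, bond_dim=bond_featurizer.dim)

mpnn.compile(
    loss=keras.losses.BinaryCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=5e-4),
    metrics=[keras.metrics.AUC(name="AUC")],
)

history = mpnn.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=40,
    verbose=2,
)
```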
Predicting
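Prediction on the held-out molecules is then a single call; the sketch below (which assumes the test_dataset and the test_index from the earlier random split) also prints a few predictions next to the true labels:

```python
# Predict permeability probabilities for the molecules in the test set
y_pred = tf.squeeze(mpnn.predict(test_dataset), axis=1).numpy()

# Show a few predictions next to the true labels
for smiles, y_true, p in zip(
    df.smiles.iloc[test_index[:5]], df.p_np.iloc[test_index[:5]], y_pred[:5]
):
    print(f"{smiles}: true={y_true}, predicted={p:.2f}")
```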
Conclusions
In this tutorial, we demonstrated a message passing neural network (MPNN) to predict blood-brain barrier permeability (BBBP) for a number of different molecules. We first had to construct graphs from SMILES, then build a Keras model that could operate on these graphs, and finally train the model to make the predictions.
Example available on HuggingFace