"""1Title: WGAN-GP with R-GCN for the generation of small molecular graphs2Author: [akensert](https://github.com/akensert)3Date created: 2021/06/304Last modified: 2021/06/305Description: Complete implementation of WGAN-GP with R-GCN to generate novel molecules.6Accelerator: GPU7"""89"""10## Introduction1112In this tutorial, we implement a generative model for graphs and use it to generate13novel molecules.1415Motivation: The [development of new drugs](https://en.wikipedia.org/wiki/Drug_development)16(molecules) can be extremely time-consuming and costly. The use of deep learning models17can alleviate the search for good candidate drugs, by predicting properties of known molecules18(e.g., solubility, toxicity, affinity to target protein, etc.). As the number of19possible molecules is astronomical, the space in which we search for/explore molecules is20just a fraction of the entire space. Therefore, it's arguably desirable to implement21generative models that can learn to generate novel molecules (which would otherwise have never been explored).2223### References (implementation)2425The implementation in this tutorial is based on/inspired by the26[MolGAN paper](https://arxiv.org/abs/1805.11973) and DeepChem's27[Basic MolGAN](https://deepchem.readthedocs.io/en/latest/api_reference/models.html#basicmolganmod28el).2930### Further reading (generative models)31Recent implementations of generative models for molecular graphs also include32[Mol-CycleGAN](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-019-0404-1),33[GraphVAE](https://arxiv.org/abs/1802.03480) and34[JT-VAE](https://arxiv.org/abs/1802.04364). For more information on generative35adverserial networks, see [GAN](https://arxiv.org/abs/1406.2661),36[WGAN](https://arxiv.org/abs/1701.07875) and [WGAN-GP](https://arxiv.org/abs/1704.00028).3738"""3940"""41## Setup4243### Install RDKit4445[RDKit](https://www.rdkit.org/) is a collection of cheminformatics and machine-learning46software written in C++ and Python. In this tutorial, RDKit is used to conveniently and47efficiently transform48[SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) to49molecule objects, and then from those obtain sets of atoms and bonds.5051SMILES expresses the structure of a given molecule in the form of an ASCII string.52The SMILES string is a compact encoding which, for smaller molecules, is relatively53human-readable. Encoding molecules as a string both alleviates and facilitates database54and/or web searching of a given molecule. 
RDKit uses algorithms to
accurately transform a given SMILES to a molecule object, which can then
be used to compute a great number of molecular properties/features.

Notice, RDKit is commonly installed via [Conda](https://www.rdkit.org/docs/Install.html).
However, thanks to
[rdkit_platform_wheels](https://github.com/kuelumbus/rdkit_platform_wheels), rdkit
can now (for the sake of this tutorial) be installed easily via pip, as follows:
```
pip -q install rdkit-pypi
```
And to allow easy visualization of molecule objects, Pillow needs to be installed:
```
pip -q install Pillow
```
"""

"""
### Import packages
"""

from rdkit import Chem, RDLogger
from rdkit.Chem.Draw import IPythonConsole, MolsToGridImage
import numpy as np
import tensorflow as tf
from tensorflow import keras

RDLogger.DisableLog("rdApp.*")

"""
## Dataset

The dataset used in this tutorial is a
[quantum mechanics dataset](http://quantum-machine.org/datasets/) (QM9), obtained from
[MoleculeNet](http://moleculenet.ai/datasets-1). Although many feature and label columns
come with the dataset, we'll only focus on the
[SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system)
column. The QM9 dataset is a good first dataset to work with for generating
graphs, as the maximum number of heavy (non-hydrogen) atoms found in a molecule is only nine.
"""

csv_path = tf.keras.utils.get_file(
    "qm9.csv", "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm9.csv"
)

data = []
with open(csv_path, "r") as f:
    for line in f.readlines()[1:]:
        data.append(line.split(",")[1])

# Let's look at a molecule of the dataset
smiles = data[1000]
print("SMILES:", smiles)
molecule = Chem.MolFromSmiles(smiles)
print("Num heavy atoms:", molecule.GetNumHeavyAtoms())
molecule

"""
### Define helper functions
These helper functions will help convert SMILES to graphs and graphs to molecule objects.

**Representing a molecular graph**. Molecules can naturally be expressed as undirected
graphs `G = (V, E)`, where `V` is a set of vertices (atoms) and `E` a set of edges
(bonds). In this implementation, each graph (molecule) will be represented as an
adjacency tensor `A`, which encodes existence/non-existence of atom pairs (bonds) with
their one-hot encoded bond types stretching an extra dimension, and a feature tensor `H`,
which for each atom one-hot encodes its atom type.

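Concretely, with the constants defined below (`NUM_ATOMS = 9`, `ATOM_DIM = 5`,
`BOND_DIM = 5`), each molecule is encoded as:

```
A.shape == (BOND_DIM, NUM_ATOMS, NUM_ATOMS)  # (5, 9, 9); last bond channel = "no bond"
H.shape == (NUM_ATOMS, ATOM_DIM)  # (9, 5); last atom channel = "no atom"
```
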
Notice, as hydrogen atoms can be inferred by
RDKit, hydrogen atoms are excluded from `A` and `H` for easier modeling.
"""

atom_mapping = {
    "C": 0,
    0: "C",
    "N": 1,
    1: "N",
    "O": 2,
    2: "O",
    "F": 3,
    3: "F",
}

bond_mapping = {
    "SINGLE": 0,
    0: Chem.BondType.SINGLE,
    "DOUBLE": 1,
    1: Chem.BondType.DOUBLE,
    "TRIPLE": 2,
    2: Chem.BondType.TRIPLE,
    "AROMATIC": 3,
    3: Chem.BondType.AROMATIC,
}

NUM_ATOMS = 9  # Maximum number of atoms
ATOM_DIM = 4 + 1  # Number of atom types
BOND_DIM = 4 + 1  # Number of bond types
LATENT_DIM = 64  # Size of the latent space


def smiles_to_graph(smiles):
    # Converts SMILES to molecule object
    molecule = Chem.MolFromSmiles(smiles)

    # Initialize adjacency and feature tensor
    adjacency = np.zeros((BOND_DIM, NUM_ATOMS, NUM_ATOMS), "float32")
    features = np.zeros((NUM_ATOMS, ATOM_DIM), "float32")

    # loop over each atom in molecule
    for atom in molecule.GetAtoms():
        i = atom.GetIdx()
        atom_type = atom_mapping[atom.GetSymbol()]
        features[i] = np.eye(ATOM_DIM)[atom_type]
        # loop over one-hop neighbors
        for neighbor in atom.GetNeighbors():
            j = neighbor.GetIdx()
            bond = molecule.GetBondBetweenAtoms(i, j)
            bond_type_idx = bond_mapping[bond.GetBondType().name]
            adjacency[bond_type_idx, [i, j], [j, i]] = 1

    # Where no bond, add 1 to last channel (indicating "non-bond")
    # Notice: channels-first
    adjacency[-1, np.sum(adjacency, axis=0) == 0] = 1

    # Where no atom, add 1 to last column (indicating "non-atom")
    features[np.where(np.sum(features, axis=1) == 0)[0], -1] = 1

    return adjacency, features

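
"""
As a quick check (illustrative only, not part of the original pipeline), we can convert
the example molecule loaded above and inspect the resulting tensor shapes:
"""

example_adjacency, example_features = smiles_to_graph(smiles)
print("adjacency shape:", example_adjacency.shape)  # (5, 9, 9)
print("features shape:", example_features.shape)  # (9, 5)
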

def graph_to_molecule(graph):
    # Unpack graph
    adjacency, features = graph

    # RWMol is a molecule object intended to be edited
    molecule = Chem.RWMol()

    # Remove "no atoms" & atoms with no bonds
    keep_idx = np.where(
        (np.argmax(features, axis=1) != ATOM_DIM - 1)
        & (np.sum(adjacency[:-1], axis=(0, 1)) != 0)
    )[0]
    features = features[keep_idx]
    adjacency = adjacency[:, keep_idx, :][:, :, keep_idx]

    # Add atoms to molecule
    for atom_type_idx in np.argmax(features, axis=1):
        atom = Chem.Atom(atom_mapping[atom_type_idx])
        _ = molecule.AddAtom(atom)

    # Add bonds between atoms in molecule; based on the upper triangles
    # of the [symmetric] adjacency tensor
    (bonds_ij, atoms_i, atoms_j) = np.where(np.triu(adjacency) == 1)
    for bond_ij, atom_i, atom_j in zip(bonds_ij, atoms_i, atoms_j):
        if atom_i == atom_j or bond_ij == BOND_DIM - 1:
            continue
        bond_type = bond_mapping[bond_ij]
        molecule.AddBond(int(atom_i), int(atom_j), bond_type)

    # Sanitize the molecule; for more information on sanitization, see
    # https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    # Let's be strict. If sanitization fails, return None
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        return None

    return molecule


# Test helper functions
graph_to_molecule(smiles_to_graph(smiles))

"""
### Generate training set

To save training time, we'll only use a tenth of the QM9 dataset.
"""

adjacency_tensor, feature_tensor = [], []
for smiles in data[::10]:
    adjacency, features = smiles_to_graph(smiles)
    adjacency_tensor.append(adjacency)
    feature_tensor.append(features)

adjacency_tensor = np.array(adjacency_tensor)
feature_tensor = np.array(feature_tensor)

print("adjacency_tensor.shape =", adjacency_tensor.shape)
print("feature_tensor.shape =", feature_tensor.shape)

"""
## Model

The idea is to implement a generator network and a discriminator network via WGAN-GP,
which will result in a generator network that can generate small novel molecules
(small graphs).

The generator network needs to be able to map (for each example in the batch) a vector `z`
to a 3-D adjacency tensor (`A`) and 2-D feature tensor (`H`). For this, `z` will first be
passed through a fully-connected network (with tanh activations), whose output will be
further passed through two separate fully-connected networks. Each of these two networks
will then output (for each example in the batch) a vector which, after a reshape and
softmax, matches the shape of the corresponding adjacency/feature tensor.

As the discriminator network will receive as input a graph (`A`, `H`) from either the
generator or from the training set, we'll need to implement graph convolutional layers,
which allow us to operate on graphs. This means that the input to the discriminator network
will first pass through graph convolutional layers, then an average-pooling layer,
and finally a few fully-connected layers. The final output should be a scalar (for each
example in the batch) which indicates the "realness" of the associated input
(in this case a "fake" or "real" molecule).

### Graph generator
"""


def GraphGenerator(
    dense_units,
    dropout_rate,
    latent_dim,
    adjacency_shape,
    feature_shape,
):
    z = keras.layers.Input(shape=(LATENT_DIM,))
    # Propagate through one or more densely connected layers
    x = z
    for units in dense_units:
        x = keras.layers.Dense(units, activation="tanh")(x)
        x = keras.layers.Dropout(dropout_rate)(x)

    # Map outputs of previous layer (x) to [continuous] adjacency tensors (x_adjacency)
    x_adjacency = keras.layers.Dense(tf.math.reduce_prod(adjacency_shape))(x)
    x_adjacency = keras.layers.Reshape(adjacency_shape)(x_adjacency)
    # Symmetrify tensors in the last two dimensions
    x_adjacency = (x_adjacency + tf.transpose(x_adjacency, (0, 1, 3, 2))) / 2
    x_adjacency = keras.layers.Softmax(axis=1)(x_adjacency)

    # Map outputs of previous layer (x) to [continuous] feature tensors (x_features)
    x_features = keras.layers.Dense(tf.math.reduce_prod(feature_shape))(x)
    x_features = keras.layers.Reshape(feature_shape)(x_features)
    x_features = keras.layers.Softmax(axis=2)(x_features)

    return keras.Model(inputs=z, outputs=[x_adjacency, x_features], name="Generator")


generator = GraphGenerator(
    dense_units=[128, 256, 512],
    dropout_rate=0.2,
    latent_dim=LATENT_DIM,
    adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
    feature_shape=(NUM_ATOMS, ATOM_DIM),
)
generator.summary()

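"""
As a quick sanity check (a minimal sketch, not part of the original pipeline), we can feed
a batch of random latent vectors through the untrained generator and confirm that the
output shapes match the adjacency and feature tensors defined earlier:
"""

z = tf.random.normal((32, LATENT_DIM))
adjacency_continuous, features_continuous = generator(z)
print("adjacency:", adjacency_continuous.shape)  # (32, 5, 9, 9)
print("features:", features_continuous.shape)  # (32, 9, 5)
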
"""
### Graph discriminator

**Graph convolutional layer**. The
[relational graph convolutional layer](https://arxiv.org/abs/1703.06103) implements
non-linearly transformed neighborhood aggregations. We can define these layers as follows:

`H^{l+1} = σ(D^{-1} @ A @ H^{l} @ W^{l})`

where `σ` denotes the non-linear transformation (commonly a ReLU activation), `A` the
adjacency tensor, `H^{l}` the feature tensor at the `l:th` layer, `D^{-1}` the inverse
diagonal degree tensor of `A`, and `W^{l}` the trainable weight tensor at the `l:th`
layer. Specifically, for each bond type (relation), the degree tensor expresses, in the
diagonal, the number of bonds attached to each atom. Notice, in this tutorial `D^{-1}` is
omitted, for two reasons: (1) it's not obvious how to apply this normalization to the
continuous adjacency tensors (generated by the generator), and (2) the WGAN seems to
perform just fine without normalization. Furthermore, in contrast to the
[original paper](https://arxiv.org/abs/1703.06103), no self-loop is defined, as we don't
want to train the generator to predict "self-bonding".
"""


class RelationalGraphConvLayer(keras.layers.Layer):
    def __init__(
        self,
        units=128,
        activation="relu",
        use_bias=False,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.units = units
        self.activation = keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)

    def build(self, input_shape):
        bond_dim = input_shape[0][1]
        atom_dim = input_shape[1][2]

        self.kernel = self.add_weight(
            shape=(bond_dim, atom_dim, self.units),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            trainable=True,
            name="W",
            dtype=tf.float32,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                shape=(bond_dim, 1, self.units),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                trainable=True,
                name="b",
                dtype=tf.float32,
            )

        self.built = True

    def call(self, inputs, training=False):
        adjacency, features = inputs
        # Aggregate information from neighbors
        x = tf.matmul(adjacency, features[:, None, :, :])
        # Apply linear transformation
        x = tf.matmul(x, self.kernel)
        if self.use_bias:
            x += self.bias
        # Reduce bond types dim
        x_reduced = tf.reduce_sum(x, axis=1)
        # Apply non-linear transformation
        return self.activation(x_reduced)

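
"""
To make the aggregation concrete (an illustrative sketch only, with hypothetical toy
inputs): applied to a single-molecule batch, the layer contracts the per-bond-type
adjacency matrices with the feature matrix, applies a per-relation linear transformation,
and sums over bond types, yielding one feature vector per atom:
"""

rgcn = RelationalGraphConvLayer(units=16)
toy_adjacency = tf.zeros((1, BOND_DIM, NUM_ATOMS, NUM_ATOMS))
toy_features = tf.zeros((1, NUM_ATOMS, ATOM_DIM))
print(rgcn([toy_adjacency, toy_features]).shape)  # (1, 9, 16)
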

def GraphDiscriminator(
    gconv_units, dense_units, dropout_rate, adjacency_shape, feature_shape
):
    adjacency = keras.layers.Input(shape=adjacency_shape)
    features = keras.layers.Input(shape=feature_shape)

    # Propagate through one or more graph convolutional layers
    features_transformed = features
    for units in gconv_units:
        features_transformed = RelationalGraphConvLayer(units)(
            [adjacency, features_transformed]
        )

    # Reduce 2-D representation of molecule to 1-D
    x = keras.layers.GlobalAveragePooling1D()(features_transformed)

    # Propagate through one or more densely connected layers
    for units in dense_units:
        x = keras.layers.Dense(units, activation="relu")(x)
        x = keras.layers.Dropout(dropout_rate)(x)

    # For each molecule, output a single scalar value expressing the
    # "realness" of the input molecule
    x_out = keras.layers.Dense(1, dtype="float32")(x)

    return keras.Model(inputs=[adjacency, features], outputs=x_out)


discriminator = GraphDiscriminator(
    gconv_units=[128, 128, 128, 128],
    dense_units=[512, 512],
    dropout_rate=0.2,
    adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
    feature_shape=(NUM_ATOMS, ATOM_DIM),
)
discriminator.summary()

"""
### WGAN-GP
"""
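"""
As a brief recap of what the class below implements (see the
[WGAN-GP paper](https://arxiv.org/abs/1704.00028)): the discriminator (critic) minimizes

`L_dis = mean(D(G(z))) - mean(D(x_real)) + gp_weight * mean((||∇D(x_interp)|| - 1)^2)`

while the generator minimizes `L_gen = -mean(D(G(z)))`, where `x_interp` is a random
interpolation between a real and a generated graph. Because each graph here is a pair of
tensors (`A`, `H`), the implementation below interpolates, and penalizes the gradients of,
the adjacency and feature tensors separately.
"""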

class GraphWGAN(keras.Model):
    def __init__(
        self,
        generator,
        discriminator,
        discriminator_steps=1,
        generator_steps=1,
        gp_weight=10,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.generator = generator
        self.discriminator = discriminator
        self.discriminator_steps = discriminator_steps
        self.generator_steps = generator_steps
        self.gp_weight = gp_weight
        self.latent_dim = self.generator.input_shape[-1]

    def compile(self, optimizer_generator, optimizer_discriminator, **kwargs):
        super().compile(**kwargs)
        self.optimizer_generator = optimizer_generator
        self.optimizer_discriminator = optimizer_discriminator
        self.metric_generator = keras.metrics.Mean(name="loss_gen")
        self.metric_discriminator = keras.metrics.Mean(name="loss_dis")

    def train_step(self, inputs):
        if isinstance(inputs[0], tuple):
            inputs = inputs[0]

        graph_real = inputs

        self.batch_size = tf.shape(inputs[0])[0]

        # Train the discriminator for one or more steps
        for _ in range(self.discriminator_steps):
            z = tf.random.normal((self.batch_size, self.latent_dim))

            with tf.GradientTape() as tape:
                graph_generated = self.generator(z, training=True)
                loss = self._loss_discriminator(graph_real, graph_generated)

            grads = tape.gradient(loss, self.discriminator.trainable_weights)
            self.optimizer_discriminator.apply_gradients(
                zip(grads, self.discriminator.trainable_weights)
            )
            self.metric_discriminator.update_state(loss)

        # Train the generator for one or more steps
        for _ in range(self.generator_steps):
            z = tf.random.normal((self.batch_size, self.latent_dim))

            with tf.GradientTape() as tape:
                graph_generated = self.generator(z, training=True)
                loss = self._loss_generator(graph_generated)

            grads = tape.gradient(loss, self.generator.trainable_weights)
            self.optimizer_generator.apply_gradients(
                zip(grads, self.generator.trainable_weights)
            )
            self.metric_generator.update_state(loss)

        return {m.name: m.result() for m in self.metrics}

    def _loss_discriminator(self, graph_real, graph_generated):
        logits_real = self.discriminator(graph_real, training=True)
        logits_generated = self.discriminator(graph_generated, training=True)
        loss = tf.reduce_mean(logits_generated) - tf.reduce_mean(logits_real)
        loss_gp = self._gradient_penalty(graph_real, graph_generated)
        return loss + loss_gp * self.gp_weight

    def _loss_generator(self, graph_generated):
        logits_generated = self.discriminator(graph_generated, training=True)
        return -tf.reduce_mean(logits_generated)

    def _gradient_penalty(self, graph_real, graph_generated):
        # Unpack graphs
        adjacency_real, features_real = graph_real
        adjacency_generated, features_generated = graph_generated

        # Generate interpolated graphs (adjacency_interp and features_interp)
        alpha = tf.random.uniform([self.batch_size])
        alpha = tf.reshape(alpha, (self.batch_size, 1, 1, 1))
        adjacency_interp = (adjacency_real * alpha) + (1 - alpha) * adjacency_generated
        alpha = tf.reshape(alpha, (self.batch_size, 1, 1))
        features_interp = (features_real * alpha) + (1 - alpha) * features_generated

        # Compute the logits of interpolated graphs
        with tf.GradientTape() as tape:
            tape.watch(adjacency_interp)
            tape.watch(features_interp)
            logits = self.discriminator(
                [adjacency_interp, features_interp], training=True
            )

        # Compute the gradients with respect to the interpolated graphs
        grads = tape.gradient(logits, [adjacency_interp, features_interp])
        # Compute the gradient penalty
        grads_adjacency_penalty = (1 - tf.norm(grads[0], axis=1)) ** 2
        grads_features_penalty = (1 - tf.norm(grads[1], axis=2)) ** 2
        return tf.reduce_mean(
            tf.reduce_mean(grads_adjacency_penalty, axis=(-2, -1))
            + tf.reduce_mean(grads_features_penalty, axis=(-1))
        )


"""
## Train the model

To save time (if run on a CPU), we'll only train the model for 10 epochs.
"""

wgan = GraphWGAN(generator, discriminator, discriminator_steps=1)

wgan.compile(
    optimizer_generator=keras.optimizers.Adam(5e-4),
    optimizer_discriminator=keras.optimizers.Adam(5e-4),
)

wgan.fit([adjacency_tensor, feature_tensor], epochs=10, batch_size=16)

"""
## Sample novel molecules with the generator
"""


def sample(generator, batch_size):
    z = tf.random.normal((batch_size, LATENT_DIM))
    graph = generator.predict(z)
    # obtain one-hot encoded adjacency tensor
    adjacency = tf.argmax(graph[0], axis=1)
    adjacency = tf.one_hot(adjacency, depth=BOND_DIM, axis=1)
    # Remove potential self-loops from adjacency
    adjacency = tf.linalg.set_diag(adjacency, tf.zeros(tf.shape(adjacency)[:-1]))
    # obtain one-hot encoded feature tensor
    features = tf.argmax(graph[1], axis=2)
    features = tf.one_hot(features, depth=ATOM_DIM, axis=2)
    return [
        graph_to_molecule([adjacency[i].numpy(), features[i].numpy()])
        for i in range(batch_size)
    ]


molecules = sample(wgan.generator, batch_size=48)

MolsToGridImage(
    [m for m in molecules if m is not None][:25], molsPerRow=5, subImgSize=(150, 150)
)

"""
## Concluding thoughts

**Inspecting the results**. Ten epochs of training seemed enough to generate some decent
looking molecules! Notice, in contrast to the
[MolGAN paper](https://arxiv.org/abs/1805.11973), the uniqueness of the generated
molecules in this tutorial seems really high, which is great!

**What we've learned, and prospects**. In this tutorial, a generative model for molecular
graphs was successfully implemented, which allowed us to generate novel molecules. In the
future, it would be interesting to implement generative models that can modify existing
molecules (for instance, to optimize solubility or protein-binding of an existing
molecule). For that however, a reconstruction loss would likely be needed, which is
tricky to implement as there's no easy and obvious way to compute similarity between two
molecular graphs.

Example available on HuggingFace:

| Trained Model | Demo |
| :--: | :--: |
| [Model on HF Hub](https://huggingface.co/keras-io/wgan-molecular-graphs) | [Spaces demo](https://huggingface.co/spaces/keras-io/Generating-molecular-graphs-by-WGAN-GP) |
"""