Path: blob/master/examples/generative/molecule_generation.py
"""1Title: Drug Molecule Generation with VAE2Author: [Victor Basu](https://www.linkedin.com/in/victor-basu-520958147)3Date created: 2022/03/104Last modified: 2024/12/175Description: Implementing a Convolutional Variational AutoEncoder (VAE) for Drug Discovery.6Accelerator: GPU7"""89"""10## Introduction1112In this example, we use a Variational Autoencoder to generate molecules for drug discovery.13We use the research papers14[Automatic chemical design using a data-driven continuous representation of molecules](https://arxiv.org/abs/1610.02415)15and [MolGAN: An implicit generative model for small molecular graphs](https://arxiv.org/abs/1805.11973)16as a reference.1718The model described in the paper **Automatic chemical design using a data-driven19continuous representation of molecules** generates new molecules via efficient exploration20of open-ended spaces of chemical compounds. The model consists of21three components: Encoder, Decoder and Predictor. The Encoder converts the discrete22representation of a molecule into a real-valued continuous vector, and the Decoder23converts these continuous vectors back to discrete molecule representations. The24Predictor estimates chemical properties from the latent continuous vector representation25of the molecule. Continuous representations allow the use of gradient-based26optimization to efficiently guide the search for optimized functional compounds.27282930**Figure (a)** - A diagram of the autoencoder used for molecule design, including the31joint property prediction model. Starting from a discrete molecule representation, such32as a SMILES string, the encoder network converts each molecule into a vector in the33latent space, which is effectively a continuous molecule representation. Given a point34in the latent space, the decoder network produces a corresponding SMILES string. A35multilayer perceptron network estimates the value of target properties associated with36each molecule.3738**Figure (b)** - Gradient-based optimization in continuous latent space. After training a39surrogate model `f(z)` to predict the properties of molecules based on their latent40representation `z`, we can optimize `f(z)` with respect to `z` to find new latent41representations expected to match specific desired properties. These new latent42representations can then be decoded into SMILES strings, at which point their properties43can be tested empirically.4445For an explanation and implementation of MolGAN, please refer to the Keras Example46[**WGAN-GP with R-GCN for the generation of small molecular graphs**](https://bit.ly/3pU6zXK) by47Alexander Kensert. Many of the functions used in the present example are from the above Keras example.48"""4950"""51## Setup5253RDKit is an open source toolkit for cheminformatics and machine learning. This toolkit come in handy54if one is into drug discovery domain. In this example, RDKit is used to conveniently55and efficiently transform SMILES to molecule objects, and then from those obtain sets of atoms56and bonds.5758Quoting from59[WGAN-GP with R-GCN for the generation of small molecular graphs](https://keras.io/examples/generative/wgan-graphs/)):6061**"SMILES expresses the structure of a given molecule in the form of an ASCII string.62The SMILES string is a compact encoding which, for smaller molecules, is relatively human-readable.63Encoding molecules as a string both alleviates and facilitates database and/or web searching64of a given molecule. 
"""
## Dataset

We use the [**ZINC – A Free Database of Commercially Available Compounds for
Virtual Screening**](https://bit.ly/3IVBI4x) dataset. The dataset comes with molecules
in SMILES representation along with their respective molecular properties such as
**logP** (water–octanol partition coefficient), **SAS** (synthetic
accessibility score) and **QED** (Qualitative Estimate of Drug-likeness).
"""

csv_path = keras.utils.get_file(
    "250k_rndm_zinc_drugs_clean_3.csv",
    "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv",
)

df = pd.read_csv(csv_path)
df["smiles"] = df["smiles"].apply(lambda s: s.replace("\n", ""))
df.head()

"""
## Hyperparameters
"""

SMILE_CHARSET = '["C", "B", "F", "I", "H", "O", "N", "S", "P", "Cl", "Br"]'

bond_mapping = {"SINGLE": 0, "DOUBLE": 1, "TRIPLE": 2, "AROMATIC": 3}
bond_mapping.update(
    {0: BondType.SINGLE, 1: BondType.DOUBLE, 2: BondType.TRIPLE, 3: BondType.AROMATIC}
)
SMILE_CHARSET = ast.literal_eval(SMILE_CHARSET)

MAX_MOLSIZE = max(df["smiles"].str.len())
SMILE_to_index = dict((c, i) for i, c in enumerate(SMILE_CHARSET))
index_to_SMILE = dict((i, c) for i, c in enumerate(SMILE_CHARSET))
atom_mapping = dict(SMILE_to_index)
atom_mapping.update(index_to_SMILE)

BATCH_SIZE = 100
EPOCHS = 10

VAE_LR = 5e-4
NUM_ATOMS = 120  # Maximum number of atoms

ATOM_DIM = len(SMILE_CHARSET)  # Number of atom types
BOND_DIM = 4 + 1  # Number of bond types
LATENT_DIM = 435  # Size of the latent space


def smiles_to_graph(smiles):
    # Converts SMILES to molecule object
    molecule = Chem.MolFromSmiles(smiles)

    # Initialize adjacency and feature tensor
    adjacency = np.zeros((BOND_DIM, NUM_ATOMS, NUM_ATOMS), "float32")
    features = np.zeros((NUM_ATOMS, ATOM_DIM), "float32")

    # loop over each atom in molecule
    for atom in molecule.GetAtoms():
        i = atom.GetIdx()
        atom_type = atom_mapping[atom.GetSymbol()]
        features[i] = np.eye(ATOM_DIM)[atom_type]
        # loop over one-hop neighbors
        for neighbor in atom.GetNeighbors():
            j = neighbor.GetIdx()
            bond = molecule.GetBondBetweenAtoms(i, j)
            bond_type_idx = bond_mapping[bond.GetBondType().name]
            adjacency[bond_type_idx, [i, j], [j, i]] = 1

    # Where no bond, add 1 to last channel (indicating "non-bond")
    # Notice: channels-first
    adjacency[-1, np.sum(adjacency, axis=0) == 0] = 1

    # Where no atom, add 1 to last column (indicating "non-atom")
    features[np.where(np.sum(features, axis=1) == 0)[0], -1] = 1

    return adjacency, features
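"""
As a hedged sanity check (an addition to the original example), we can convert a single
SMILES string from the dataset into its graph representation and confirm that the
tensor shapes match the hyperparameters defined above.
"""

example_adjacency, example_features = smiles_to_graph(df["smiles"].iloc[0])
print("Adjacency tensor shape:", example_adjacency.shape)  # (BOND_DIM, NUM_ATOMS, NUM_ATOMS)
print("Feature tensor shape:  ", example_features.shape)  # (NUM_ATOMS, ATOM_DIM)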
def graph_to_molecule(graph):
    # Unpack graph
    adjacency, features = graph

    # RWMol is a molecule object intended to be edited
    molecule = Chem.RWMol()

    # Remove "no atoms" & atoms with no bonds
    keep_idx = np.where(
        (np.argmax(features, axis=1) != ATOM_DIM - 1)
        & (np.sum(adjacency[:-1], axis=(0, 1)) != 0)
    )[0]
    features = features[keep_idx]
    adjacency = adjacency[:, keep_idx, :][:, :, keep_idx]

    # Add atoms to molecule
    for atom_type_idx in np.argmax(features, axis=1):
        atom = Chem.Atom(atom_mapping[atom_type_idx])
        _ = molecule.AddAtom(atom)

    # Add bonds between atoms in molecule; based on the upper triangles
    # of the [symmetric] adjacency tensor
    (bonds_ij, atoms_i, atoms_j) = np.where(np.triu(adjacency) == 1)
    for bond_ij, atom_i, atom_j in zip(bonds_ij, atoms_i, atoms_j):
        if atom_i == atom_j or bond_ij == BOND_DIM - 1:
            continue
        bond_type = bond_mapping[bond_ij]
        molecule.AddBond(int(atom_i), int(atom_j), bond_type)

    # Sanitize the molecule; for more information on sanitization, see
    # https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    # Let's be strict. If sanitization fails, return None
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        return None

    return molecule
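"""
Before building the training set, a small round-trip check (our addition): encode a
SMILES string to graph tensors and decode it back. Note that the graph encodes neither
stereochemistry nor formal charges, so the recovered SMILES may differ from the input,
and `graph_to_molecule` returns `None` whenever sanitization fails.
"""

roundtrip_molecule = graph_to_molecule(smiles_to_graph(df["smiles"].iloc[0]))
print("Input SMILES:    ", df["smiles"].iloc[0])
print(
    "Recovered SMILES:",
    None if roundtrip_molecule is None else Chem.MolToSmiles(roundtrip_molecule),
)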
"""
## Generate training set
"""

train_df = df.sample(frac=0.75, random_state=42)  # random state is a seed value
train_df.reset_index(drop=True, inplace=True)

adjacency_tensor, feature_tensor, qed_tensor = [], [], []
for idx in range(8000):
    adjacency, features = smiles_to_graph(train_df.loc[idx]["smiles"])
    qed = train_df.loc[idx]["qed"]
    adjacency_tensor.append(adjacency)
    feature_tensor.append(features)
    qed_tensor.append(qed)

adjacency_tensor = np.array(adjacency_tensor)
feature_tensor = np.array(feature_tensor)
qed_tensor = np.array(qed_tensor)


class RelationalGraphConvLayer(keras.layers.Layer):
    def __init__(
        self,
        units=128,
        activation="relu",
        use_bias=False,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.units = units
        self.activation = keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)

    def build(self, input_shape):
        bond_dim = input_shape[0][1]
        atom_dim = input_shape[1][2]

        self.kernel = self.add_weight(
            shape=(bond_dim, atom_dim, self.units),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            trainable=True,
            name="W",
            dtype="float32",
        )

        if self.use_bias:
            self.bias = self.add_weight(
                shape=(bond_dim, 1, self.units),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                trainable=True,
                name="b",
                dtype="float32",
            )

        self.built = True

    def call(self, inputs, training=False):
        adjacency, features = inputs
        # Aggregate information from neighbors
        x = ops.matmul(adjacency, features[:, None])
        # Apply linear transformation
        x = ops.matmul(x, self.kernel)
        if self.use_bias:
            x += self.bias
        # Reduce bond types dim
        x_reduced = ops.sum(x, axis=1)
        # Apply non-linear transformation
        return self.activation(x_reduced)


"""
## Build the Encoder and Decoder

The Encoder takes as input a molecule's graph adjacency matrix and feature matrix.
These features are processed via a Graph Convolution layer, then reduced by global
average pooling and processed by several Dense layers to derive `z_mean` and `log_var`,
the latent-space representation of the molecule.

**Graph Convolution layer**: The relational graph convolution layer implements
non-linearly transformed neighbourhood aggregations. We can define these layers as
follows:

`H_hat**(l+1) = σ(D_hat**(-1) * A_hat * H_hat**(l) * W_hat**(l))`

Where `σ` denotes the non-linear transformation (commonly a ReLU activation), `A_hat` the
adjacency tensor, `H_hat**(l)` the feature tensor at the `l`-th layer, `D_hat**(-1)` the
inverse diagonal degree tensor of `A_hat`, and `W_hat**(l)` the trainable weight tensor
at the `l`-th layer. Specifically, for each bond type (relation), the degree tensor
expresses, in the diagonal, the number of bonds attached to each atom.

Source:
[WGAN-GP with R-GCN for the generation of small molecular graphs](https://keras.io/examples/generative/wgan-graphs/)

The Decoder takes as input the latent-space representation and predicts
the graph adjacency matrix and feature matrix of the corresponding molecules.
"""


def get_encoder(
    gconv_units, latent_dim, adjacency_shape, feature_shape, dense_units, dropout_rate
):
    adjacency = layers.Input(shape=adjacency_shape)
    features = layers.Input(shape=feature_shape)

    # Propagate through one or more graph convolutional layers
    features_transformed = features
    for units in gconv_units:
        features_transformed = RelationalGraphConvLayer(units)(
            [adjacency, features_transformed]
        )
    # Reduce 2-D representation of molecule to 1-D
    x = layers.GlobalAveragePooling1D()(features_transformed)

    # Propagate through one or more densely connected layers
    for units in dense_units:
        x = layers.Dense(units, activation="relu")(x)
        x = layers.Dropout(dropout_rate)(x)

    z_mean = layers.Dense(latent_dim, dtype="float32", name="z_mean")(x)
    log_var = layers.Dense(latent_dim, dtype="float32", name="log_var")(x)

    encoder = keras.Model([adjacency, features], [z_mean, log_var], name="encoder")

    return encoder


def get_decoder(dense_units, dropout_rate, latent_dim, adjacency_shape, feature_shape):
    latent_inputs = keras.Input(shape=(latent_dim,))

    x = latent_inputs
    for units in dense_units:
        x = layers.Dense(units, activation="tanh")(x)
        x = layers.Dropout(dropout_rate)(x)

    # Map outputs of previous layer (x) to [continuous] adjacency tensors (x_adjacency)
    x_adjacency = layers.Dense(np.prod(adjacency_shape))(x)
    x_adjacency = layers.Reshape(adjacency_shape)(x_adjacency)
    # Symmetrify tensors in the last two dimensions
    x_adjacency = (x_adjacency + ops.transpose(x_adjacency, (0, 1, 3, 2))) / 2
    x_adjacency = layers.Softmax(axis=1)(x_adjacency)

    # Map outputs of previous layer (x) to [continuous] feature tensors (x_features)
    x_features = layers.Dense(np.prod(feature_shape))(x)
    x_features = layers.Reshape(feature_shape)(x_features)
    x_features = layers.Softmax(axis=2)(x_features)

    decoder = keras.Model(
        latent_inputs, outputs=[x_adjacency, x_features], name="decoder"
    )

    return decoder


"""
## Build the Sampling layer
"""


class Sampling(layers.Layer):
    def __init__(self, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(seed)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch, dim = ops.shape(z_log_var)
        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon
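"""
A short standalone sketch (our addition) of the reparameterization trick the `Sampling`
layer implements: `z = z_mean + exp(0.5 * z_log_var) * epsilon` with `epsilon ~ N(0, I)`,
which keeps the stochastic sampling step differentiable with respect to the encoder
outputs. The zero-valued toy tensors below are illustrative only.
"""

toy_mean = ops.zeros((2, LATENT_DIM))
toy_log_var = ops.zeros((2, LATENT_DIM))  # log-variance 0, i.e. unit variance
toy_z = Sampling(seed=42)([toy_mean, toy_log_var])
print("Sampled z shape:", toy_z.shape)  # (2, LATENT_DIM)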
"""
## Build the VAE

This model is trained to optimize four losses:

* Categorical crossentropy
* KL divergence loss
* Property prediction loss
* Graph loss (gradient penalty)

The categorical crossentropy loss function measures the model's
reconstruction accuracy. The property prediction loss measures the discrepancy
between predicted and actual properties (here, QED) after running the latent
representation through a property prediction model; since QED values lie in [0, 1],
it is computed via binary crossentropy. The gradient
penalty is further guided by the model's property (QED) prediction.

A gradient penalty is an alternative soft constraint on
1-Lipschitz continuity, introduced as an improvement over the weight clipping scheme
of the original WGAN
("1-Lipschitz continuity" means that the norm of the gradient is at most 1 at every single
point of the function).
It adds a regularization term to the loss function.
"""


class MoleculeGenerator(keras.Model):
    def __init__(self, encoder, decoder, max_len, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.property_prediction_layer = layers.Dense(1)
        self.max_len = max_len
        self.seed_generator = keras.random.SeedGenerator(seed)
        self.sampling_layer = Sampling(seed=seed)

        self.train_total_loss_tracker = keras.metrics.Mean(name="train_total_loss")
        self.val_total_loss_tracker = keras.metrics.Mean(name="val_total_loss")

    def train_step(self, data):
        adjacency_tensor, feature_tensor, qed_tensor = data[0]
        graph_real = [adjacency_tensor, feature_tensor]
        self.batch_size = ops.shape(qed_tensor)[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, qed_pred, gen_adjacency, gen_features = self(
                graph_real, training=True
            )
            graph_generated = [gen_adjacency, gen_features]
            total_loss = self._compute_loss(
                z_log_var, z_mean, qed_tensor, qed_pred, graph_real, graph_generated
            )

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.train_total_loss_tracker.update_state(total_loss)
        return {"loss": self.train_total_loss_tracker.result()}

    def _compute_loss(
        self, z_log_var, z_mean, qed_true, qed_pred, graph_real, graph_generated
    ):
        adjacency_real, features_real = graph_real
        adjacency_gen, features_gen = graph_generated

        adjacency_loss = ops.mean(
            ops.sum(
                keras.losses.categorical_crossentropy(
                    adjacency_real, adjacency_gen, axis=1
                ),
                axis=(1, 2),
            )
        )
        features_loss = ops.mean(
            ops.sum(
                keras.losses.categorical_crossentropy(features_real, features_gen),
                axis=(1),
            )
        )
        kl_loss = -0.5 * ops.sum(
            1 + z_log_var - z_mean**2 - ops.minimum(ops.exp(z_log_var), 1e6), 1
        )
        kl_loss = ops.mean(kl_loss)

        property_loss = ops.mean(
            keras.losses.binary_crossentropy(qed_true, ops.squeeze(qed_pred, axis=1))
        )

        graph_loss = self._gradient_penalty(graph_real, graph_generated)

        return kl_loss + property_loss + graph_loss + adjacency_loss + features_loss

    def _gradient_penalty(self, graph_real, graph_generated):
        # Unpack graphs
        adjacency_real, features_real = graph_real
        adjacency_generated, features_generated = graph_generated

        # Generate interpolated graphs (adjacency_interp and features_interp)
        alpha = keras.random.uniform(
            shape=(self.batch_size,), seed=self.seed_generator
        )
        alpha = ops.reshape(alpha, (self.batch_size, 1, 1, 1))
        adjacency_interp = (adjacency_real * alpha) + (
            1.0 - alpha
        ) * adjacency_generated
        alpha = ops.reshape(alpha, (self.batch_size, 1, 1))
        features_interp = (features_real * alpha) + (1.0 - alpha) * features_generated

        # Compute the logits of interpolated graphs
        with tf.GradientTape() as tape:
            tape.watch(adjacency_interp)
            tape.watch(features_interp)
            _, _, logits, _, _ = self(
                [adjacency_interp, features_interp], training=True
            )

        # Compute the gradients with respect to the interpolated graphs
        grads = tape.gradient(logits, [adjacency_interp, features_interp])
        # Compute the gradient penalty
        grads_adjacency_penalty = (1 - ops.norm(grads[0], axis=1)) ** 2
        grads_features_penalty = (1 - ops.norm(grads[1], axis=2)) ** 2
        return ops.mean(
            ops.mean(grads_adjacency_penalty, axis=(-2, -1))
            + ops.mean(grads_features_penalty, axis=(-1))
        )

    def inference(self, batch_size):
        z = keras.random.normal(
            shape=(batch_size, LATENT_DIM), seed=self.seed_generator
        )
        reconstruction_adjacency, reconstruction_features = self.decoder.predict(z)
        # Obtain one-hot encoded adjacency tensor
        adjacency = ops.argmax(reconstruction_adjacency, axis=1)
        adjacency = ops.one_hot(adjacency, num_classes=BOND_DIM, axis=1)
        # Remove potential self-loops from adjacency
        adjacency = adjacency * (1.0 - ops.eye(NUM_ATOMS, dtype="float32")[None, None])
        # Obtain one-hot encoded feature tensor
        features = ops.argmax(reconstruction_features, axis=2)
        features = ops.one_hot(features, num_classes=ATOM_DIM, axis=2)
        return [
            graph_to_molecule([adjacency[i].numpy(), features[i].numpy()])
            for i in range(batch_size)
        ]

    def call(self, inputs):
        z_mean, log_var = self.encoder(inputs)
        z = self.sampling_layer([z_mean, log_var])

        gen_adjacency, gen_features = self.decoder(z)

        property_pred = self.property_prediction_layer(z_mean)

        return z_mean, log_var, property_pred, gen_adjacency, gen_features
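"""
To build intuition for the gradient penalty above, here is a self-contained toy (our
addition, not the model's own computation): interpolate between "real" and "generated"
inputs, differentiate a stand-in scalar function at the interpolated points, and
penalize the squared deviation of the gradient norm from 1.
"""

x_real = tf.random.normal((4, 3))
x_generated = tf.random.normal((4, 3))
alpha_toy = tf.random.uniform((4, 1))
x_interp = alpha_toy * x_real + (1.0 - alpha_toy) * x_generated
with tf.GradientTape() as toy_tape:
    toy_tape.watch(x_interp)
    toy_logits = ops.sum(x_interp**2, axis=-1)  # stand-in for the property head
toy_grads = toy_tape.gradient(toy_logits, x_interp)
toy_penalty = ops.mean((1.0 - ops.norm(toy_grads, axis=-1)) ** 2)
print("Toy gradient penalty:", float(toy_penalty))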
"""
## Train the model
"""

vae_optimizer = keras.optimizers.Adam(learning_rate=VAE_LR)

encoder = get_encoder(
    gconv_units=[9],
    adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
    feature_shape=(NUM_ATOMS, ATOM_DIM),
    latent_dim=LATENT_DIM,
    dense_units=[512],
    dropout_rate=0.0,
)
decoder = get_decoder(
    dense_units=[128, 256, 512],
    dropout_rate=0.2,
    latent_dim=LATENT_DIM,
    adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
    feature_shape=(NUM_ATOMS, ATOM_DIM),
)

model = MoleculeGenerator(encoder, decoder, MAX_MOLSIZE)

model.compile(vae_optimizer)
history = model.fit([adjacency_tensor, feature_tensor, qed_tensor], epochs=EPOCHS)

"""
## Inference

We use our model to generate new valid molecules from different points of the latent space.
"""

"""
### Generate unique molecules with the model
"""

molecules = model.inference(1000)

MolsToGridImage(
    [m for m in molecules if m is not None][:1000], molsPerRow=5, subImgSize=(260, 160)
)

"""
### Display latent space clusters with respect to molecular properties (QED)
"""


def plot_latent(vae, data, labels):
    # Display a 2D plot of the property in the latent space
    z_mean, _ = vae.encoder.predict(data)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()


plot_latent(model, [adjacency_tensor[:8000], feature_tensor[:8000]], qed_tensor[:8000])
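"""
Finally, a hedged sketch of the gradient-based search that Figure (b) describes (our
addition; the step size and step count are illustrative assumptions, not tuned values):
starting from a random latent point, ascend the property head's QED prediction with
respect to `z`, then decode the optimized point the same way `inference` does. Since
the property head here is a single linear layer, the ascent direction is constant;
with a deeper surrogate `f(z)` it would vary with `z`.
"""

z_opt = tf.Variable(keras.random.normal(shape=(1, LATENT_DIM), seed=0))
for _ in range(50):
    with tf.GradientTape() as opt_tape:
        predicted_qed = model.property_prediction_layer(z_opt)
    z_grad = opt_tape.gradient(predicted_qed, z_opt)
    z_opt.assign_add(0.01 * z_grad)  # gradient ascent on the predicted property

# Decode the optimized latent point following the same steps as `inference`
adj_opt, feat_opt = model.decoder.predict(z_opt.numpy())
adj_opt = ops.argmax(adj_opt, axis=1)
adj_opt = ops.one_hot(adj_opt, num_classes=BOND_DIM, axis=1)
adj_opt = adj_opt * (1.0 - ops.eye(NUM_ATOMS, dtype="float32")[None, None])
feat_opt = ops.argmax(feat_opt, axis=2)
feat_opt = ops.one_hot(feat_opt, num_classes=ATOM_DIM, axis=2)
optimized_molecule = graph_to_molecule(
    [ops.convert_to_numpy(adj_opt)[0], ops.convert_to_numpy(feat_opt)[0]]
)
print(
    "Optimized molecule:",
    None if optimized_molecule is None else Chem.MolToSmiles(optimized_molecule),
)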
of610molecules" from 2016 and the "MolGAN" paper from 2018. The former paper611treats SMILES inputs as strings and seeks to generate molecule strings in SMILES format,612while the later paper considers SMILES inputs as graphs (a combination of adjacency613matrices and feature matrices) and seeks to generate molecules as graphs.614615This hybrid approach enables a new type of directed gradient-based search through chemical space.616617Example available on HuggingFace618619| Trained Model | Demo |620| :--: | :--: |621| [](https://huggingface.co/keras-io/drug-molecule-generation-with-VAE) | [](https://huggingface.co/spaces/keras-io/generating-drug-molecule-with-VAE) |622"""623624625