"""
Title: Drug Molecule Generation with VAE
Author: [Victor Basu](https://www.linkedin.com/in/victor-basu-520958147)
Date created: 2022/03/10
Last modified: 2024/12/17
Description: Implementing a Convolutional Variational AutoEncoder (VAE) for Drug Discovery.
Accelerator: GPU
"""

"""
## Introduction

In this example, we use a Variational Autoencoder to generate molecules for drug discovery.
We use the research papers
[Automatic chemical design using a data-driven continuous representation of molecules](https://arxiv.org/abs/1610.02415)
and [MolGAN: An implicit generative model for small molecular graphs](https://arxiv.org/abs/1805.11973)
as references.

The model described in the paper **Automatic chemical design using a data-driven
continuous representation of molecules** generates new molecules via efficient exploration
of open-ended spaces of chemical compounds. The model consists of
three components: an Encoder, a Decoder and a Predictor. The Encoder converts the discrete
representation of a molecule into a real-valued continuous vector, and the Decoder
converts these continuous vectors back to discrete molecule representations. The
Predictor estimates chemical properties from the latent continuous vector representation
of the molecule. Continuous representations allow the use of gradient-based
optimization to efficiently guide the search for optimized functional compounds.

![intro](https://bit.ly/3CtPMzM)

**Figure (a)** - A diagram of the autoencoder used for molecule design, including the
joint property prediction model. Starting from a discrete molecule representation, such
as a SMILES string, the encoder network converts each molecule into a vector in the
latent space, which is effectively a continuous molecule representation. Given a point
in the latent space, the decoder network produces a corresponding SMILES string. A
multilayer perceptron network estimates the value of target properties associated with
each molecule.

**Figure (b)** - Gradient-based optimization in continuous latent space. After training a
surrogate model `f(z)` to predict the properties of molecules based on their latent
representation `z`, we can optimize `f(z)` with respect to `z` to find new latent
representations expected to match specific desired properties. These new latent
representations can then be decoded into SMILES strings, at which point their properties
can be tested empirically.
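
To make this search concrete, here is a minimal sketch of the idea; note that `f`,
`latent_dim`, `steps`, and `learning_rate` are placeholders, not objects defined in this
example:

```python
# Sketch: gradient ascent on a trained surrogate property model f(z).
z = tf.Variable(tf.random.normal((1, latent_dim)))  # starting latent point
for _ in range(steps):
    with tf.GradientTape() as tape:
        property_estimate = f(z)  # surrogate prediction at the current point
    grad = tape.gradient(property_estimate, z)
    z.assign_add(learning_rate * grad)  # step toward a higher predicted property
# Decode z into a SMILES string and test its properties empirically.
```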

For an explanation and implementation of MolGAN, please refer to the Keras Example
[**WGAN-GP with R-GCN for the generation of small molecular graphs**](https://bit.ly/3pU6zXK) by
Alexander Kensert. Many of the functions used in the present example are from the above Keras example.
"""

"""
## Setup

RDKit is an open source toolkit for cheminformatics and machine learning. It comes in handy
if one is working in the drug discovery domain. In this example, RDKit is used to conveniently
and efficiently transform SMILES to molecule objects, and then from those obtain sets of atoms
and bonds.

Quoting from
[WGAN-GP with R-GCN for the generation of small molecular graphs](https://keras.io/examples/generative/wgan-graphs/):

**"SMILES expresses the structure of a given molecule in the form of an ASCII string.
The SMILES string is a compact encoding which, for smaller molecules, is relatively human-readable.
Encoding molecules as a string both alleviates and facilitates database and/or web searching
of a given molecule. RDKit uses algorithms to accurately transform a given SMILES to
a molecule object, which can then be used to compute a great number of molecular properties/features."**
"""

"""shell
pip -q install rdkit-pypi==2021.9.4
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import ast

import pandas as pd
import numpy as np

import tensorflow as tf
import keras
from keras import layers
from keras import ops

import matplotlib.pyplot as plt
from rdkit import Chem, RDLogger
from rdkit.Chem import BondType
from rdkit.Chem.Draw import MolsToGridImage

RDLogger.DisableLog("rdApp.*")
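
"""
As a quick illustrative check (an addition to this example, not part of the original
pipeline), RDKit can parse a single SMILES string into a molecule object; `"CCO"`
(ethanol) is just an arbitrary test string.
"""

demo_molecule = Chem.MolFromSmiles("CCO")
print("Heavy atoms in demo molecule:", demo_molecule.GetNumAtoms())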

"""
## Dataset

We use the [**ZINC – A Free Database of Commercially Available Compounds for
Virtual Screening**](https://bit.ly/3IVBI4x) dataset. The dataset comes with molecule
formulas in SMILES representation, along with their respective molecular properties such as
**logP** (water–octanol partition coefficient), **SAS** (synthetic
accessibility score) and **QED** (Qualitative Estimate of Drug-likeness).
"""

csv_path = keras.utils.get_file(
    "250k_rndm_zinc_drugs_clean_3.csv",
    "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv",
)

df = pd.read_csv(csv_path)
df["smiles"] = df["smiles"].apply(lambda s: s.replace("\n", ""))
df.head()

"""
## Hyperparameters
"""

SMILE_CHARSET = '["C", "B", "F", "I", "H", "O", "N", "S", "P", "Cl", "Br"]'

bond_mapping = {"SINGLE": 0, "DOUBLE": 1, "TRIPLE": 2, "AROMATIC": 3}
bond_mapping.update(
    {0: BondType.SINGLE, 1: BondType.DOUBLE, 2: BondType.TRIPLE, 3: BondType.AROMATIC}
)
SMILE_CHARSET = ast.literal_eval(SMILE_CHARSET)

MAX_MOLSIZE = max(df["smiles"].str.len())
SMILE_to_index = dict((c, i) for i, c in enumerate(SMILE_CHARSET))
index_to_SMILE = dict((i, c) for i, c in enumerate(SMILE_CHARSET))
atom_mapping = dict(SMILE_to_index)
atom_mapping.update(index_to_SMILE)

BATCH_SIZE = 100
EPOCHS = 10

VAE_LR = 5e-4
NUM_ATOMS = 120  # Maximum number of atoms

ATOM_DIM = len(SMILE_CHARSET)  # Number of atom types
BOND_DIM = 4 + 1  # Number of bond types (including "no bond")
LATENT_DIM = 435  # Size of the latent space
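
"""
Note that `atom_mapping` is bidirectional (a small illustrative check added here): atom
symbols map to indices, and indices map back to symbols.
"""

print(atom_mapping["C"], atom_mapping[0])  # prints: 0 C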


def smiles_to_graph(smiles):
    # Converts SMILES to molecule object
    molecule = Chem.MolFromSmiles(smiles)

    # Initialize adjacency and feature tensor
    adjacency = np.zeros((BOND_DIM, NUM_ATOMS, NUM_ATOMS), "float32")
    features = np.zeros((NUM_ATOMS, ATOM_DIM), "float32")

    # loop over each atom in molecule
    for atom in molecule.GetAtoms():
        i = atom.GetIdx()
        atom_type = atom_mapping[atom.GetSymbol()]
        features[i] = np.eye(ATOM_DIM)[atom_type]
        # loop over one-hop neighbors
        for neighbor in atom.GetNeighbors():
            j = neighbor.GetIdx()
            bond = molecule.GetBondBetweenAtoms(i, j)
            bond_type_idx = bond_mapping[bond.GetBondType().name]
            adjacency[bond_type_idx, [i, j], [j, i]] = 1

    # Where no bond, add 1 to last channel (indicating "non-bond")
    # Notice: channels-first
    adjacency[-1, np.sum(adjacency, axis=0) == 0] = 1

    # Where no atom, add 1 to last column (indicating "non-atom")
    features[np.where(np.sum(features, axis=1) == 0)[0], -1] = 1

    return adjacency, features


def graph_to_molecule(graph):
    # Unpack graph
    adjacency, features = graph

    # RWMol is a molecule object intended to be edited
    molecule = Chem.RWMol()

    # Remove "no atoms" & atoms with no bonds
    keep_idx = np.where(
        (np.argmax(features, axis=1) != ATOM_DIM - 1)
        & (np.sum(adjacency[:-1], axis=(0, 1)) != 0)
    )[0]
    features = features[keep_idx]
    adjacency = adjacency[:, keep_idx, :][:, :, keep_idx]

    # Add atoms to molecule
    for atom_type_idx in np.argmax(features, axis=1):
        atom = Chem.Atom(atom_mapping[atom_type_idx])
        _ = molecule.AddAtom(atom)

    # Add bonds between atoms in molecule; based on the upper triangles
    # of the [symmetric] adjacency tensor
    (bonds_ij, atoms_i, atoms_j) = np.where(np.triu(adjacency) == 1)
    for bond_ij, atom_i, atom_j in zip(bonds_ij, atoms_i, atoms_j):
        if atom_i == atom_j or bond_ij == BOND_DIM - 1:
            continue
        bond_type = bond_mapping[bond_ij]
        molecule.AddBond(int(atom_i), int(atom_j), bond_type)

    # Sanitize the molecule; for more information on sanitization, see
    # https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    # Let's be strict. If sanitization fails, return None
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        return None

    return molecule
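
"""
As an illustrative round-trip check (an addition to this example), we can encode a
SMILES string into graph tensors and decode it back into a molecule; `"CCO"` is again
an arbitrary test molecule.
"""

demo_graph = smiles_to_graph("CCO")
print(Chem.MolToSmiles(graph_to_molecule(demo_graph)))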

"""
## Generate training set
"""

train_df = df.sample(frac=0.75, random_state=42)  # random state is a seed value
train_df.reset_index(drop=True, inplace=True)

adjacency_tensor, feature_tensor, qed_tensor = [], [], []
for idx in range(8000):
    adjacency, features = smiles_to_graph(train_df.loc[idx]["smiles"])
    qed = train_df.loc[idx]["qed"]
    adjacency_tensor.append(adjacency)
    feature_tensor.append(features)
    qed_tensor.append(qed)

adjacency_tensor = np.array(adjacency_tensor)
feature_tensor = np.array(feature_tensor)
qed_tensor = np.array(qed_tensor)


class RelationalGraphConvLayer(keras.layers.Layer):
    def __init__(
        self,
        units=128,
        activation="relu",
        use_bias=False,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.units = units
        self.activation = keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)

    def build(self, input_shape):
        bond_dim = input_shape[0][1]
        atom_dim = input_shape[1][2]

        self.kernel = self.add_weight(
            shape=(bond_dim, atom_dim, self.units),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            trainable=True,
            name="W",
            dtype="float32",
        )

        if self.use_bias:
            self.bias = self.add_weight(
                shape=(bond_dim, 1, self.units),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                trainable=True,
                name="b",
                dtype="float32",
            )

        self.built = True

    def call(self, inputs, training=False):
        adjacency, features = inputs
        # Aggregate information from neighbors
        x = ops.matmul(adjacency, features[:, None])
        # Apply linear transformation
        x = ops.matmul(x, self.kernel)
        if self.use_bias:
            x += self.bias
        # Reduce bond types dim
        x_reduced = ops.sum(x, axis=1)
        # Apply non-linear transformation
        return self.activation(x_reduced)
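
"""
An illustrative shape check (an addition to this example): for a small batch of graphs,
the layer aggregates neighbor features per bond type and maps each atom to `units`
output channels.
"""

demo_gconv = RelationalGraphConvLayer(units=32)
demo_output = demo_gconv([adjacency_tensor[:2], feature_tensor[:2]])
print(demo_output.shape)  # (2, NUM_ATOMS, 32)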

"""
## Build the Encoder and Decoder

The Encoder takes as input a molecule's graph adjacency matrix and feature matrix.
These features are processed via a Graph Convolution layer, then are flattened and
processed by several Dense layers to derive `z_mean` and `log_var`, the
latent-space representation of the molecule.

**Graph Convolution layer**: The relational graph convolution layer implements
non-linearly transformed neighbourhood aggregations. We can define these layers as
follows:

`H_hat**(l+1) = σ(D_hat**(-1) * A_hat * H_hat**(l) * W_hat**(l))`

Where `σ` denotes the non-linear transformation (commonly a ReLU activation), `A_hat` the
adjacency tensor, `H_hat**(l)` the feature tensor at the `l-th` layer, `D_hat**(-1)` the
inverse diagonal degree tensor of `A_hat`, and `W_hat**(l)` the trainable weight tensor
at the `l-th` layer. Specifically, for each bond type (relation), the degree tensor
expresses, in the diagonal, the number of bonds attached to each atom.

Source:
[WGAN-GP with R-GCN for the generation of small molecular graphs](https://keras.io/examples/generative/wgan-graphs/)

The Decoder takes as input the latent-space representation and predicts
the graph adjacency matrix and feature matrix of the corresponding molecules.
"""


def get_encoder(
    gconv_units, latent_dim, adjacency_shape, feature_shape, dense_units, dropout_rate
):
    adjacency = layers.Input(shape=adjacency_shape)
    features = layers.Input(shape=feature_shape)

    # Propagate through one or more graph convolutional layers
    features_transformed = features
    for units in gconv_units:
        features_transformed = RelationalGraphConvLayer(units)(
            [adjacency, features_transformed]
        )
    # Reduce 2-D representation of molecule to 1-D
    x = layers.GlobalAveragePooling1D()(features_transformed)

    # Propagate through one or more densely connected layers
    for units in dense_units:
        x = layers.Dense(units, activation="relu")(x)
        x = layers.Dropout(dropout_rate)(x)

    z_mean = layers.Dense(latent_dim, dtype="float32", name="z_mean")(x)
    log_var = layers.Dense(latent_dim, dtype="float32", name="log_var")(x)

    encoder = keras.Model([adjacency, features], [z_mean, log_var], name="encoder")

    return encoder


def get_decoder(dense_units, dropout_rate, latent_dim, adjacency_shape, feature_shape):
    latent_inputs = keras.Input(shape=(latent_dim,))

    x = latent_inputs
    for units in dense_units:
        x = layers.Dense(units, activation="tanh")(x)
        x = layers.Dropout(dropout_rate)(x)

    # Map outputs of previous layer (x) to [continuous] adjacency tensors (x_adjacency)
    x_adjacency = layers.Dense(np.prod(adjacency_shape))(x)
    x_adjacency = layers.Reshape(adjacency_shape)(x_adjacency)
    # Symmetrify tensors in the last two dimensions
    x_adjacency = (x_adjacency + ops.transpose(x_adjacency, (0, 1, 3, 2))) / 2
    x_adjacency = layers.Softmax(axis=1)(x_adjacency)

    # Map outputs of previous layer (x) to [continuous] feature tensors (x_features)
    x_features = layers.Dense(np.prod(feature_shape))(x)
    x_features = layers.Reshape(feature_shape)(x_features)
    x_features = layers.Softmax(axis=2)(x_features)

    decoder = keras.Model(
        latent_inputs, outputs=[x_adjacency, x_features], name="decoder"
    )

    return decoder

"""
## Build the Sampling layer

This layer implements the reparameterization trick: it draws `epsilon` from a standard
normal distribution and shifts and scales it with the predicted `z_mean` and `z_log_var`,
so gradients can flow through the sampling step.
"""


class Sampling(layers.Layer):
    def __init__(self, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(seed)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch, dim = ops.shape(z_log_var)
        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon
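
"""
A one-line illustrative check (an addition to this example): with zero mean and zero
log-variance, the layer reduces to a draw from a standard normal distribution.
"""

demo_z = Sampling()([ops.zeros((1, LATENT_DIM)), ops.zeros((1, LATENT_DIM))])
print(demo_z.shape)  # (1, LATENT_DIM)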

"""
## Build the VAE

This model is trained to optimize four losses:

* Categorical crossentropy (reconstruction loss)
* KL divergence loss
* Property prediction loss
* Graph loss (gradient penalty)

The categorical crossentropy loss measures the model's reconstruction accuracy over the
adjacency and feature tensors. The property prediction loss measures the discrepancy
between the true and predicted properties after running the latent representation through
a property prediction model; in this example it is computed via binary crossentropy on
the QED scores. The gradient penalty term is computed from the model's property (QED)
predictions on graphs interpolated between real and generated ones.

A gradient penalty is an alternative soft constraint on 1-Lipschitz continuity,
introduced as an improvement upon the weight clipping scheme of the original WGAN
("1-Lipschitz continuity" means that the norm of the gradient is at most 1 at every
single point of the function). It adds a regularization term to the loss function.
"""


class MoleculeGenerator(keras.Model):
    def __init__(self, encoder, decoder, max_len, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.property_prediction_layer = layers.Dense(1)
        self.max_len = max_len
        self.seed_generator = keras.random.SeedGenerator(seed)
        self.sampling_layer = Sampling(seed=seed)

        self.train_total_loss_tracker = keras.metrics.Mean(name="train_total_loss")
        self.val_total_loss_tracker = keras.metrics.Mean(name="val_total_loss")

    def train_step(self, data):
        adjacency_tensor, feature_tensor, qed_tensor = data[0]
        graph_real = [adjacency_tensor, feature_tensor]
        self.batch_size = ops.shape(qed_tensor)[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, qed_pred, gen_adjacency, gen_features = self(
                graph_real, training=True
            )
            graph_generated = [gen_adjacency, gen_features]
            total_loss = self._compute_loss(
                z_log_var, z_mean, qed_tensor, qed_pred, graph_real, graph_generated
            )

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.train_total_loss_tracker.update_state(total_loss)
        return {"loss": self.train_total_loss_tracker.result()}

    def _compute_loss(
        self, z_log_var, z_mean, qed_true, qed_pred, graph_real, graph_generated
    ):
        adjacency_real, features_real = graph_real
        adjacency_gen, features_gen = graph_generated

        adjacency_loss = ops.mean(
            ops.sum(
                keras.losses.categorical_crossentropy(
                    adjacency_real, adjacency_gen, axis=1
                ),
                axis=(1, 2),
            )
        )
        features_loss = ops.mean(
            ops.sum(
                keras.losses.categorical_crossentropy(features_real, features_gen),
                axis=(1),
            )
        )
        kl_loss = -0.5 * ops.sum(
            1 + z_log_var - z_mean**2 - ops.minimum(ops.exp(z_log_var), 1e6), 1
        )
        kl_loss = ops.mean(kl_loss)

        property_loss = ops.mean(
            keras.losses.binary_crossentropy(qed_true, ops.squeeze(qed_pred, axis=1))
        )

        graph_loss = self._gradient_penalty(graph_real, graph_generated)

        return kl_loss + property_loss + graph_loss + adjacency_loss + features_loss

    def _gradient_penalty(self, graph_real, graph_generated):
        # Unpack graphs
        adjacency_real, features_real = graph_real
        adjacency_generated, features_generated = graph_generated

        # Generate interpolated graphs (adjacency_interp and features_interp)
        alpha = keras.random.uniform(
            shape=(self.batch_size,), seed=self.seed_generator
        )
        alpha = ops.reshape(alpha, (self.batch_size, 1, 1, 1))
        adjacency_interp = (adjacency_real * alpha) + (
            1.0 - alpha
        ) * adjacency_generated
        alpha = ops.reshape(alpha, (self.batch_size, 1, 1))
        features_interp = (features_real * alpha) + (1.0 - alpha) * features_generated

        # Compute the logits of interpolated graphs
        with tf.GradientTape() as tape:
            tape.watch(adjacency_interp)
            tape.watch(features_interp)
            _, _, logits, _, _ = self(
                [adjacency_interp, features_interp], training=True
            )

        # Compute the gradients with respect to the interpolated graphs
        grads = tape.gradient(logits, [adjacency_interp, features_interp])
        # Compute the gradient penalty
        grads_adjacency_penalty = (1 - ops.norm(grads[0], axis=1)) ** 2
        grads_features_penalty = (1 - ops.norm(grads[1], axis=2)) ** 2
        return ops.mean(
            ops.mean(grads_adjacency_penalty, axis=(-2, -1))
            + ops.mean(grads_features_penalty, axis=(-1))
        )

    def inference(self, batch_size):
        z = keras.random.normal(
            shape=(batch_size, LATENT_DIM), seed=self.seed_generator
        )
        reconstruction_adjacency, reconstruction_features = self.decoder.predict(z)
        # obtain one-hot encoded adjacency tensor
        adjacency = ops.argmax(reconstruction_adjacency, axis=1)
        adjacency = ops.one_hot(adjacency, num_classes=BOND_DIM, axis=1)
        # Remove potential self-loops from adjacency
        adjacency = adjacency * (1.0 - ops.eye(NUM_ATOMS, dtype="float32")[None, None])
        # obtain one-hot encoded feature tensor
        features = ops.argmax(reconstruction_features, axis=2)
        features = ops.one_hot(features, num_classes=ATOM_DIM, axis=2)
        return [
            graph_to_molecule([adjacency[i].numpy(), features[i].numpy()])
            for i in range(batch_size)
        ]

    def call(self, inputs):
        z_mean, log_var = self.encoder(inputs)
        z = self.sampling_layer([z_mean, log_var])

        gen_adjacency, gen_features = self.decoder(z)

        property_pred = self.property_prediction_layer(z_mean)

        return z_mean, log_var, property_pred, gen_adjacency, gen_features


"""
## Train the model
"""

vae_optimizer = keras.optimizers.Adam(learning_rate=VAE_LR)

encoder = get_encoder(
    gconv_units=[9],
    adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
    feature_shape=(NUM_ATOMS, ATOM_DIM),
    latent_dim=LATENT_DIM,
    dense_units=[512],
    dropout_rate=0.0,
)
decoder = get_decoder(
    dense_units=[128, 256, 512],
    dropout_rate=0.2,
    latent_dim=LATENT_DIM,
    adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
    feature_shape=(NUM_ATOMS, ATOM_DIM),
)

model = MoleculeGenerator(encoder, decoder, MAX_MOLSIZE)

model.compile(vae_optimizer)
history = model.fit([adjacency_tensor, feature_tensor, qed_tensor], epochs=EPOCHS)
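
"""
Optionally (an addition to this example), we can visualize the total training loss
recorded by `model.fit`.
"""

plt.plot(history.history["loss"])
plt.xlabel("Epoch")
plt.ylabel("Total loss")
plt.show()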

"""
## Inference

We use our model to generate new valid molecules from different points of the latent space.
"""

"""
### Generate unique molecules with the model
"""

molecules = model.inference(1000)

MolsToGridImage(
    [m for m in molecules if m is not None][:1000], molsPerRow=5, subImgSize=(260, 160)
)
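
"""
As a quick illustrative metric (an addition to this example), we can report how many of
the sampled latent points decoded into valid molecules.
"""

valid_molecules = [m for m in molecules if m is not None]
print(f"{len(valid_molecules)} valid molecules out of {len(molecules)} sampled")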

"""
### Display latent space clusters with respect to molecular properties (QED)
"""


def plot_latent(vae, data, labels):
    # display a 2D plot of the property in the latent space
    z_mean, _ = vae.encoder.predict(data)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()


plot_latent(model, [adjacency_tensor[:8000], feature_tensor[:8000]], qed_tensor[:8000])

"""
## Conclusion

In this example, we combined model architectures from two papers:
"Automatic chemical design using a data-driven continuous representation of
molecules" from 2016 and the "MolGAN" paper from 2018. The former paper
treats SMILES inputs as strings and seeks to generate molecule strings in SMILES format,
while the latter paper considers SMILES inputs as graphs (a combination of adjacency
matrices and feature matrices) and seeks to generate molecules as graphs.

This hybrid approach enables a new type of directed gradient-based search through chemical space.

Example available on HuggingFace

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-molecule%20generation%20with%20VAE-black.svg)](https://huggingface.co/keras-io/drug-molecule-generation-with-VAE) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-molecule%20generation%20with%20VAE-black.svg)](https://huggingface.co/spaces/keras-io/generating-drug-molecule-with-VAE) |
"""