"""
Title: DCGAN to generate face images
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/04/29
Last modified: 2023/12/21
Description: A simple DCGAN trained on CelebA face images with `fit()` by overriding `train_step`.
Accelerator: GPU
"""

"""
## Setup
"""

import keras
import tensorflow as tf

from keras import layers
from keras import ops
import matplotlib.pyplot as plt
import os
import gdown
from zipfile import ZipFile

"""
26
## Prepare CelebA data
27
28
We'll use face images from the CelebA dataset, resized to 64x64.
29
"""
30
31
os.makedirs("celeba_gan")
32
33
url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
34
output = "celeba_gan/data.zip"
35
gdown.download(url, output, quiet=True)
36
37
with ZipFile("celeba_gan/data.zip", "r") as zipobj:
38
zipobj.extractall("celeba_gan")
39
40
"""
41
Create a dataset from our folder, and rescale the images to the [0-1] range:
42
"""
43
44
dataset = keras.utils.image_dataset_from_directory(
45
"celeba_gan", label_mode=None, image_size=(64, 64), batch_size=32
46
)
47
dataset = dataset.map(lambda x: x / 255.0)
48
49
50
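"""
Optional sanity check (not part of the original example): after rescaling, the
pixel values should lie in the [0, 1] range, which is also the range of the
sigmoid output used by the generator defined below.
"""

for x in dataset.take(1):
    # Print the min and max pixel values of one batch; expected roughly 0.0 and 1.0.
    print(float(tf.reduce_min(x)), float(tf.reduce_max(x)))
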
"""
51
Let's display a sample image:
52
"""
53
54
55
for x in dataset:
56
plt.axis("off")
57
plt.imshow((x.numpy() * 255).astype("int32")[0])
58
break
59
60
61
"""
62
## Create the discriminator
63
64
It maps a 64x64 image to a binary classification score.
65
"""
66
67
discriminator = keras.Sequential(
68
[
69
keras.Input(shape=(64, 64, 3)),
70
layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),
71
layers.LeakyReLU(negative_slope=0.2),
72
layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
73
layers.LeakyReLU(negative_slope=0.2),
74
layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
75
layers.LeakyReLU(negative_slope=0.2),
76
layers.Flatten(),
77
layers.Dropout(0.2),
78
layers.Dense(1, activation="sigmoid"),
79
],
80
name="discriminator",
81
)
82
discriminator.summary()
83
84
"""
85
## Create the generator
86
87
It mirrors the discriminator, replacing `Conv2D` layers with `Conv2DTranspose` layers.
88
"""
89
90
latent_dim = 128
91
92
generator = keras.Sequential(
93
[
94
keras.Input(shape=(latent_dim,)),
95
layers.Dense(8 * 8 * 128),
96
layers.Reshape((8, 8, 128)),
97
layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
98
layers.LeakyReLU(negative_slope=0.2),
99
layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
100
layers.LeakyReLU(negative_slope=0.2),
101
layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),
102
layers.LeakyReLU(negative_slope=0.2),
103
layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),
104
],
105
name="generator",
106
)
107
generator.summary()
108
109
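"""
As a quick, optional smoke test (not part of the original example), we can chain
the two untrained models the same way `train_step` will below: sample a small
batch of latent vectors, decode them with the generator, and score the results
with the discriminator. The printed shapes should be (4, 64, 64, 3) and (4, 1).
"""

sample_latents = keras.random.normal(shape=(4, latent_dim))
sample_images = generator(sample_latents)          # decode noise into images
sample_scores = discriminator(sample_images)       # one "realness" score per image
print(sample_images.shape, sample_scores.shape)
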
"""
110
## Override `train_step`
111
"""
112
113
114
class GAN(keras.Model):
115
def __init__(self, discriminator, generator, latent_dim):
116
super().__init__()
117
self.discriminator = discriminator
118
self.generator = generator
119
self.latent_dim = latent_dim
120
self.seed_generator = keras.random.SeedGenerator(1337)
121
122
def compile(self, d_optimizer, g_optimizer, loss_fn):
123
super().compile()
124
self.d_optimizer = d_optimizer
125
self.g_optimizer = g_optimizer
126
self.loss_fn = loss_fn
127
self.d_loss_metric = keras.metrics.Mean(name="d_loss")
128
self.g_loss_metric = keras.metrics.Mean(name="g_loss")
129
130
@property
131
def metrics(self):
132
return [self.d_loss_metric, self.g_loss_metric]
133
134
def train_step(self, real_images):
135
# Sample random points in the latent space
136
batch_size = ops.shape(real_images)[0]
137
random_latent_vectors = keras.random.normal(
138
shape=(batch_size, self.latent_dim), seed=self.seed_generator
139
)
140
141
# Decode them to fake images
142
generated_images = self.generator(random_latent_vectors)
143
144
# Combine them with real images
145
combined_images = ops.concatenate([generated_images, real_images], axis=0)
146
147
# Assemble labels discriminating real from fake images
148
labels = ops.concatenate(
149
[ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
150
)
151
# Add random noise to the labels - important trick!
152
labels += 0.05 * tf.random.uniform(tf.shape(labels))
153
154
# Train the discriminator
155
with tf.GradientTape() as tape:
156
predictions = self.discriminator(combined_images)
157
d_loss = self.loss_fn(labels, predictions)
158
grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
159
self.d_optimizer.apply_gradients(
160
zip(grads, self.discriminator.trainable_weights)
161
)
162
163
# Sample random points in the latent space
164
random_latent_vectors = keras.random.normal(
165
shape=(batch_size, self.latent_dim), seed=self.seed_generator
166
)
167
168
# Assemble labels that say "all real images"
169
misleading_labels = ops.zeros((batch_size, 1))
170
171
# Train the generator (note that we should *not* update the weights
172
# of the discriminator)!
173
with tf.GradientTape() as tape:
174
predictions = self.discriminator(self.generator(random_latent_vectors))
175
g_loss = self.loss_fn(misleading_labels, predictions)
176
grads = tape.gradient(g_loss, self.generator.trainable_weights)
177
self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
178
179
# Update metrics
180
self.d_loss_metric.update_state(d_loss)
181
self.g_loss_metric.update_state(g_loss)
182
return {
183
"d_loss": self.d_loss_metric.result(),
184
"g_loss": self.g_loss_metric.result(),
185
}
186
187
188
"""
189
## Create a callback that periodically saves generated images
190
"""
191
192
193
class GANMonitor(keras.callbacks.Callback):
194
def __init__(self, num_img=3, latent_dim=128):
195
self.num_img = num_img
196
self.latent_dim = latent_dim
197
self.seed_generator = keras.random.SeedGenerator(42)
198
199
def on_epoch_end(self, epoch, logs=None):
200
random_latent_vectors = keras.random.normal(
201
shape=(self.num_img, self.latent_dim), seed=self.seed_generator
202
)
203
generated_images = self.model.generator(random_latent_vectors)
204
generated_images *= 255
205
generated_images.numpy()
206
for i in range(self.num_img):
207
img = keras.utils.array_to_img(generated_images[i])
208
img.save("generated_img_%03d_%d.png" % (epoch, i))
209
210
211
"""
212
## Train the end-to-end model
213
"""
214
215
epochs = 1 # In practice, use ~100 epochs
216
217
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
218
gan.compile(
219
d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
220
g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
221
loss_fn=keras.losses.BinaryCrossentropy(),
222
)
223
224
gan.fit(
225
dataset, epochs=epochs, callbacks=[GANMonitor(num_img=10, latent_dim=latent_dim)]
226
)
227
228
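"""
As an optional check (not part of the original example), we can look at one of
the images that `GANMonitor` just wrote to disk; with `epochs = 1` and the
callback settings above, files named `generated_img_000_0.png` through
`generated_img_000_9.png` should exist in the working directory.
"""

plt.axis("off")
plt.imshow(keras.utils.load_img("generated_img_000_0.png"))
plt.show()
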
"""
229
Some of the last generated images around epoch 30
230
(results keep improving after that):
231
232
![results](https://i.imgur.com/h5MtQZ7l.png)
233
"""
234
235
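
"""
Finally, a minimal, illustrative sketch (not part of the original example) of
how to sample new faces from the trained generator and save the generator for
later reuse. The batch size and file names here are arbitrary.
"""

random_latent_vectors = keras.random.normal(shape=(16, latent_dim))
# Decode the latent vectors and convert to NumPy for image export.
new_faces = ops.convert_to_numpy(gan.generator(random_latent_vectors))
for i in range(16):
    img = keras.utils.array_to_img(new_faces[i])
    img.save("sampled_face_%02d.png" % i)

# Save the generator sub-model on its own in the native Keras format.
gan.generator.save("dcgan_generator.keras")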