"""
2
Title: Face image generation with StyleGAN
3
Author: [Soon-Yau Cheong](https://www.linkedin.com/in/soonyau/)
4
Date created: 2021/07/01
5
Last modified: 2021/07/01
6
Description: Implementation of StyleGAN for image generation.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
The key idea of StyleGAN is to progressively increase the resolution of the generated
14
images and to incorporate style features in the generative process.This
15
[StyleGAN](https://arxiv.org/abs/1812.04948) implementation is based on the book
16
[Hands-on Image Generation with TensorFlow](https://www.amazon.com/dp/1838826785).
17
The code from the book's
18
[GitHub repository](https://github.com/PacktPublishing/Hands-On-Image-Generation-with-TensorFlow-2.0/tree/master/Chapter07)
19
was refactored to leverage a custom `train_step()` to enable
20
faster training time via compilation and distribution.
21
"""
"""
## Setup
"""

"""
### Install latest TFA
"""
"""shell
pip install tensorflow_addons
"""

import os
import numpy as np
import matplotlib.pyplot as plt

from functools import partial

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow_addons.layers import InstanceNormalization

import gdown
from zipfile import ZipFile
"""
50
## Prepare the dataset
51
52
In this example, we will train using the CelebA from the project GDrive.
53
"""
54
55
56
def log2(x):
57
return int(np.log2(x))
58
59
60
# we use different batch size for different resolution, so larger image size
61
# could fit into GPU memory. The keys is image resolution in log2
62
batch_sizes = {2: 16, 3: 16, 4: 16, 5: 16, 6: 16, 7: 8, 8: 4, 9: 2, 10: 1}
63
# We adjust the train step accordingly
64
train_step_ratio = {k: batch_sizes[2] / v for k, v in batch_sizes.items()}
65
66
67
os.makedirs("celeba_gan")
68
69
url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
70
output = "celeba_gan/data.zip"
71
gdown.download(url, output, quiet=True)
72
73
with ZipFile("celeba_gan/data.zip", "r") as zipobj:
74
zipobj.extractall("celeba_gan")
75
76
# Create a dataset from our folder, and rescale the images to the [0-1] range:
77
78
ds_train = keras.utils.image_dataset_from_directory(
79
"celeba_gan", label_mode=None, image_size=(64, 64), batch_size=32
80
)
81
82
83
def resize_image(res, image):
84
# only downsampling, so use nearest neighbor that is faster to run
85
image = tf.image.resize(
86
image, (res, res), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
87
)
88
image = tf.cast(image, tf.float32) / 127.5 - 1.0
89
return image
90
91
92
def create_dataloader(res):
93
batch_size = batch_sizes[log2(res)]
94
# NOTE: we unbatch the dataset so we can `batch()` it again with the `drop_remainder=True` option
95
# since the model only supports a single batch size
96
dl = ds_train.map(
97
partial(resize_image, res), num_parallel_calls=tf.data.AUTOTUNE
98
).unbatch()
99
dl = dl.shuffle(200).batch(batch_size, drop_remainder=True).prefetch(1).repeat()
100
return dl
101
102
103
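"""
As a quick, illustrative sanity check (not part of the original training flow), we can
pull a single batch from a dataloader and confirm its shape and value range:
"""

# One 4x4 batch should have shape (16, 4, 4, 3) with pixel values in [-1, 1].
sample_batch = next(iter(create_dataloader(res=4)))
print(sample_batch.shape, float(tf.reduce_min(sample_batch)), float(tf.reduce_max(sample_batch)))
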
"""
104
## Utility function to display images after each epoch
105
"""
106
107
108
def plot_images(images, log2_res, fname=""):
109
scales = {2: 0.5, 3: 1, 4: 2, 5: 3, 6: 4, 7: 5, 8: 6, 9: 7, 10: 8}
110
scale = scales[log2_res]
111
112
grid_col = min(images.shape[0], int(32 // scale))
113
grid_row = 1
114
115
f, axarr = plt.subplots(
116
grid_row, grid_col, figsize=(grid_col * scale, grid_row * scale)
117
)
118
119
for row in range(grid_row):
120
ax = axarr if grid_row == 1 else axarr[row]
121
for col in range(grid_col):
122
ax[col].imshow(images[row * grid_col + col])
123
ax[col].axis("off")
124
plt.show()
125
if fname:
126
f.savefig(fname)
127
128
129
"""
130
## Custom Layers
131
132
The following are building blocks that will be used to construct the generators and
133
discriminators of the StyleGAN model.
134
"""
135
136
137
def fade_in(alpha, a, b):
138
return alpha * a + (1.0 - alpha) * b
139
140
141
def wasserstein_loss(y_true, y_pred):
142
return -tf.reduce_mean(y_true * y_pred)
143
144
145
def pixel_norm(x, epsilon=1e-8):
146
return x / tf.math.sqrt(tf.reduce_mean(x**2, axis=-1, keepdims=True) + epsilon)
147
148
149
def minibatch_std(input_tensor, epsilon=1e-8):
150
n, h, w, c = tf.shape(input_tensor)
151
group_size = tf.minimum(4, n)
152
x = tf.reshape(input_tensor, [group_size, -1, h, w, c])
153
group_mean, group_var = tf.nn.moments(x, axes=(0), keepdims=False)
154
group_std = tf.sqrt(group_var + epsilon)
155
avg_std = tf.reduce_mean(group_std, axis=[1, 2, 3], keepdims=True)
156
x = tf.tile(avg_std, [group_size, h, w, 1])
157
return tf.concat([input_tensor, x], axis=-1)
158
159
160
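"""
A minimal check of `pixel_norm` (illustrative only, not part of the original script):
after normalization, the mean of the squared features at each pixel should be
approximately 1.
"""

x_demo = tf.random.normal((1, 2, 2, 8))
print(tf.reduce_mean(pixel_norm(x_demo) ** 2, axis=-1))  # ~1.0 at every pixel
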
class EqualizedConv(layers.Layer):
    def __init__(self, out_channels, kernel=3, gain=2, **kwargs):
        super().__init__(**kwargs)
        self.kernel = kernel
        self.out_channels = out_channels
        self.gain = gain
        self.pad = kernel != 1

    def build(self, input_shape):
        self.in_channels = input_shape[-1]
        initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
        self.w = self.add_weight(
            shape=[self.kernel, self.kernel, self.in_channels, self.out_channels],
            initializer=initializer,
            trainable=True,
            name="kernel",
        )
        self.b = self.add_weight(
            shape=(self.out_channels,), initializer="zeros", trainable=True, name="bias"
        )
        fan_in = self.kernel * self.kernel * self.in_channels
        # Equalized learning rate: weights are drawn from N(0, 1) and rescaled at
        # runtime by the He initialization constant sqrt(gain / fan_in).
        self.scale = tf.sqrt(self.gain / fan_in)

    def call(self, inputs):
        if self.pad:
            x = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
        else:
            x = inputs
        output = (
            tf.nn.conv2d(x, self.scale * self.w, strides=1, padding="VALID") + self.b
        )
        return output
class EqualizedDense(layers.Layer):
    def __init__(self, units, gain=2, learning_rate_multiplier=1, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.gain = gain
        self.learning_rate_multiplier = learning_rate_multiplier

    def build(self, input_shape):
        self.in_channels = input_shape[-1]
        initializer = keras.initializers.RandomNormal(
            mean=0.0, stddev=1.0 / self.learning_rate_multiplier
        )
        self.w = self.add_weight(
            shape=[self.in_channels, self.units],
            initializer=initializer,
            trainable=True,
            name="kernel",
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="zeros", trainable=True, name="bias"
        )
        fan_in = self.in_channels
        self.scale = tf.sqrt(self.gain / fan_in)

    def call(self, inputs):
        output = tf.add(tf.matmul(inputs, self.scale * self.w), self.b)
        return output * self.learning_rate_multiplier
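"""
A quick, assumed shape check for the equalized layers (illustrative, not part of the
original script): the REFLECT padding in `EqualizedConv` preserves spatial size.
"""

y_demo = EqualizedConv(out_channels=8, kernel=3)(tf.random.normal((2, 16, 16, 4)))
print(y_demo.shape)  # (2, 16, 16, 8)
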
class AddNoise(layers.Layer):
    def build(self, input_shape):
        n, h, w, c = input_shape[0]
        initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
        # Learned per-channel scaling factor for the injected noise.
        self.b = self.add_weight(
            shape=[1, 1, 1, c], initializer=initializer, trainable=True, name="kernel"
        )

    def call(self, inputs):
        x, noise = inputs
        output = x + self.b * noise
        return output
class AdaIN(layers.Layer):
    def __init__(self, gain=1, **kwargs):
        super().__init__(**kwargs)
        self.gain = gain

    def build(self, input_shapes):
        x_shape = input_shapes[0]
        w_shape = input_shapes[1]

        self.w_channels = w_shape[-1]
        self.x_channels = x_shape[-1]

        self.dense_1 = EqualizedDense(self.x_channels, gain=1)
        self.dense_2 = EqualizedDense(self.x_channels, gain=1)

    def call(self, inputs):
        x, w = inputs
        # Style modulation: predict a per-channel scale (ys) and bias (yb) from w.
        ys = tf.reshape(self.dense_1(w), (-1, 1, 1, self.x_channels))
        yb = tf.reshape(self.dense_2(w), (-1, 1, 1, self.x_channels))
        return ys * x + yb
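"""
Illustrative use of `AdaIN` (assumed shapes, not part of the original script): a style
vector scales and shifts a feature map channel-wise.
"""

x_feat = tf.random.normal((2, 8, 8, 32))
w_style = tf.random.normal((2, 512))
print(AdaIN()([x_feat, w_style]).shape)  # (2, 8, 8, 32)
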
"""
260
Next we build the following:
261
262
- A model mapping to map the random noise into style code
263
- The generator
264
- The discriminator
265
266
For the generator, we build generator blocks at multiple resolutions,
267
e.g. 4x4, 8x8, ...up to 1024x1024. We only use 4x4 in the beginning
268
and we use progressively larger-resolution blocks as the training proceeds.
269
Same for the discriminator.
270
"""
271
272
273
def Mapping(num_stages, input_shape=512):
274
z = layers.Input(shape=(input_shape))
275
w = pixel_norm(z)
276
for i in range(8):
277
w = EqualizedDense(512, learning_rate_multiplier=0.01)(w)
278
w = layers.LeakyReLU(0.2)(w)
279
w = tf.tile(tf.expand_dims(w, 1), (1, num_stages, 1))
280
return keras.Model(z, w, name="mapping")
281
282
283
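"""
A small, assumed sanity check: the mapping network turns a batch of latent vectors
into one style code per generator stage.
"""

w_demo = Mapping(num_stages=6)(tf.random.normal((2, 512)))
print(w_demo.shape)  # (2, 6, 512)
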
class Generator:
    def __init__(self, start_res_log2, target_res_log2):
        self.start_res_log2 = start_res_log2
        self.target_res_log2 = target_res_log2
        self.num_stages = target_res_log2 - start_res_log2 + 1
        # list of generator blocks at increasing resolution
        self.g_blocks = []
        # list of layers to convert g_block activation to RGB
        self.to_rgb = []
        # list of noise inputs of different resolutions into g_blocks
        self.noise_inputs = []
        # filter size to use at each stage, keys are log2(resolution)
        self.filter_nums = {
            0: 512,
            1: 512,
            2: 512,  # 4x4
            3: 512,  # 8x8
            4: 512,  # 16x16
            5: 512,  # 32x32
            6: 256,  # 64x64
            7: 128,  # 128x128
            8: 64,  # 256x256
            9: 32,  # 512x512
            10: 16,  # 1024x1024
        }

        start_res = 2**start_res_log2
        self.input_shape = (start_res, start_res, self.filter_nums[start_res_log2])
        self.g_input = layers.Input(self.input_shape, name="generator_input")

        for i in range(start_res_log2, target_res_log2 + 1):
            filter_num = self.filter_nums[i]
            res = 2**i
            self.noise_inputs.append(
                layers.Input(shape=(res, res, 1), name=f"noise_{res}x{res}")
            )
            to_rgb = Sequential(
                [
                    layers.InputLayer(input_shape=(res, res, filter_num)),
                    EqualizedConv(3, 1, gain=1),
                ],
                name=f"to_rgb_{res}x{res}",
            )
            self.to_rgb.append(to_rgb)
            is_base = i == self.start_res_log2
            if is_base:
                input_shape = (res, res, self.filter_nums[i - 1])
            else:
                input_shape = (2 ** (i - 1), 2 ** (i - 1), self.filter_nums[i - 1])
            g_block = self.build_block(
                filter_num, res=res, input_shape=input_shape, is_base=is_base
            )
            self.g_blocks.append(g_block)

    def build_block(self, filter_num, res, input_shape, is_base):
        input_tensor = layers.Input(shape=input_shape, name=f"g_{res}")
        noise = layers.Input(shape=(res, res, 1), name=f"noise_{res}")
        w = layers.Input(shape=512)
        x = input_tensor

        if not is_base:
            x = layers.UpSampling2D((2, 2))(x)
            x = EqualizedConv(filter_num, 3)(x)

        x = AddNoise()([x, noise])
        x = layers.LeakyReLU(0.2)(x)
        x = InstanceNormalization()(x)
        x = AdaIN()([x, w])

        x = EqualizedConv(filter_num, 3)(x)
        x = AddNoise()([x, noise])
        x = layers.LeakyReLU(0.2)(x)
        x = InstanceNormalization()(x)
        x = AdaIN()([x, w])
        return keras.Model([input_tensor, w, noise], x, name=f"genblock_{res}x{res}")

    def grow(self, res_log2):
        res = 2**res_log2

        num_stages = res_log2 - self.start_res_log2 + 1
        w = layers.Input(shape=(self.num_stages, 512), name="w")

        alpha = layers.Input(shape=(1), name="g_alpha")
        x = self.g_blocks[0]([self.g_input, w[:, 0], self.noise_inputs[0]])

        if num_stages == 1:
            rgb = self.to_rgb[0](x)
        else:
            for i in range(1, num_stages - 1):
                x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])

            old_rgb = self.to_rgb[num_stages - 2](x)
            old_rgb = layers.UpSampling2D((2, 2))(old_rgb)

            i = num_stages - 1
            x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])

            new_rgb = self.to_rgb[i](x)

            # During the transition phase, blend the new stage's RGB output with the
            # up-sampled output of the previous stage.
            rgb = fade_in(alpha[0], new_rgb, old_rgb)

        return keras.Model(
            [self.g_input, w, self.noise_inputs, alpha],
            rgb,
            name=f"generator_{res}_x_{res}",
        )
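"""
Illustrative only (an assumption of this write-up, not part of the original script):
growing the generator to 16x16 (3 stages from a 4x4 base) and checking the output
resolution.
"""

g_demo = Generator(start_res_log2=2, target_res_log2=4).grow(res_log2=4)
print(g_demo.output_shape)  # (None, 16, 16, 3)
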
class Discriminator:
    def __init__(self, start_res_log2, target_res_log2):
        self.start_res_log2 = start_res_log2
        self.target_res_log2 = target_res_log2
        self.num_stages = target_res_log2 - start_res_log2 + 1
        # filter size to use at each stage, keys are log2(resolution)
        self.filter_nums = {
            0: 512,
            1: 512,
            2: 512,  # 4x4
            3: 512,  # 8x8
            4: 512,  # 16x16
            5: 512,  # 32x32
            6: 256,  # 64x64
            7: 128,  # 128x128
            8: 64,  # 256x256
            9: 32,  # 512x512
            10: 16,  # 1024x1024
        }
        # list of discriminator blocks at increasing resolution
        self.d_blocks = []
        # list of layers to convert RGB into activation for d_blocks inputs
        self.from_rgb = []

        for res_log2 in range(self.start_res_log2, self.target_res_log2 + 1):
            res = 2**res_log2
            filter_num = self.filter_nums[res_log2]
            from_rgb = Sequential(
                [
                    layers.InputLayer(
                        input_shape=(res, res, 3), name=f"from_rgb_input_{res}"
                    ),
                    EqualizedConv(filter_num, 1),
                    layers.LeakyReLU(0.2),
                ],
                name=f"from_rgb_{res}",
            )

            self.from_rgb.append(from_rgb)

            input_shape = (res, res, filter_num)
            if len(self.d_blocks) == 0:
                d_block = self.build_base(filter_num, res)
            else:
                d_block = self.build_block(
                    filter_num, self.filter_nums[res_log2 - 1], res
                )

            self.d_blocks.append(d_block)

    def build_base(self, filter_num, res):
        input_tensor = layers.Input(shape=(res, res, filter_num), name=f"d_{res}")
        x = minibatch_std(input_tensor)
        x = EqualizedConv(filter_num, 3)(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Flatten()(x)
        x = EqualizedDense(filter_num)(x)
        x = layers.LeakyReLU(0.2)(x)
        x = EqualizedDense(1)(x)
        return keras.Model(input_tensor, x, name=f"d_{res}")

    def build_block(self, filter_num_1, filter_num_2, res):
        input_tensor = layers.Input(shape=(res, res, filter_num_1), name=f"d_{res}")
        x = EqualizedConv(filter_num_1, 3)(input_tensor)
        x = layers.LeakyReLU(0.2)(x)
        x = EqualizedConv(filter_num_2)(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.AveragePooling2D((2, 2))(x)
        return keras.Model(input_tensor, x, name=f"d_{res}")

    def grow(self, res_log2):
        res = 2**res_log2
        idx = res_log2 - self.start_res_log2
        alpha = layers.Input(shape=(1), name="d_alpha")
        input_image = layers.Input(shape=(res, res, 3), name="input_image")
        x = self.from_rgb[idx](input_image)
        x = self.d_blocks[idx](x)
        if idx > 0:
            idx -= 1
            downsized_image = layers.AveragePooling2D((2, 2))(input_image)
            y = self.from_rgb[idx](downsized_image)
            x = fade_in(alpha[0], x, y)

            for i in range(idx, -1, -1):
                x = self.d_blocks[i](x)
        return keras.Model([input_image, alpha], x, name=f"discriminator_{res}_x_{res}")
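"""
An assumed smoke test, mirroring the generator demo above: a grown discriminator maps
a batch of images to one realness score each.
"""

d_demo = Discriminator(start_res_log2=2, target_res_log2=4).grow(res_log2=4)
print(d_demo.output_shape)  # (None, 1)
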
"""
480
## Build StyleGAN with custom train step
481
"""
482
483
484
class StyleGAN(tf.keras.Model):
485
def __init__(self, z_dim=512, target_res=64, start_res=4):
486
super().__init__()
487
self.z_dim = z_dim
488
489
self.target_res_log2 = log2(target_res)
490
self.start_res_log2 = log2(start_res)
491
self.current_res_log2 = self.target_res_log2
492
self.num_stages = self.target_res_log2 - self.start_res_log2 + 1
493
494
self.alpha = tf.Variable(1.0, dtype=tf.float32, trainable=False, name="alpha")
495
496
self.mapping = Mapping(num_stages=self.num_stages)
497
self.d_builder = Discriminator(self.start_res_log2, self.target_res_log2)
498
self.g_builder = Generator(self.start_res_log2, self.target_res_log2)
499
self.g_input_shape = self.g_builder.input_shape
500
501
self.phase = None
502
self.train_step_counter = tf.Variable(0, dtype=tf.int32, trainable=False)
503
504
self.loss_weights = {"gradient_penalty": 10, "drift": 0.001}
505
506
def grow_model(self, res):
507
tf.keras.backend.clear_session()
508
res_log2 = log2(res)
509
self.generator = self.g_builder.grow(res_log2)
510
self.discriminator = self.d_builder.grow(res_log2)
511
self.current_res_log2 = res_log2
512
print(f"\nModel resolution:{res}x{res}")
513
514
def compile(
515
self, steps_per_epoch, phase, res, d_optimizer, g_optimizer, *args, **kwargs
516
):
517
self.loss_weights = kwargs.pop("loss_weights", self.loss_weights)
518
self.steps_per_epoch = steps_per_epoch
519
if res != 2**self.current_res_log2:
520
self.grow_model(res)
521
self.d_optimizer = d_optimizer
522
self.g_optimizer = g_optimizer
523
524
self.train_step_counter.assign(0)
525
self.phase = phase
526
self.d_loss_metric = keras.metrics.Mean(name="d_loss")
527
self.g_loss_metric = keras.metrics.Mean(name="g_loss")
528
super().compile(*args, **kwargs)
529
530
@property
531
def metrics(self):
532
return [self.d_loss_metric, self.g_loss_metric]
533
534
def generate_noise(self, batch_size):
535
noise = [
536
tf.random.normal((batch_size, 2**res, 2**res, 1))
537
for res in range(self.start_res_log2, self.target_res_log2 + 1)
538
]
539
return noise
540
541
def gradient_loss(self, grad):
542
loss = tf.square(grad)
543
loss = tf.reduce_sum(loss, axis=tf.range(1, tf.size(tf.shape(loss))))
544
loss = tf.sqrt(loss)
545
loss = tf.reduce_mean(tf.square(loss - 1))
546
return loss
547
548
def train_step(self, real_images):
549
self.train_step_counter.assign_add(1)
550
551
if self.phase == "TRANSITION":
552
self.alpha.assign(
553
tf.cast(self.train_step_counter / self.steps_per_epoch, tf.float32)
554
)
555
elif self.phase == "STABLE":
556
self.alpha.assign(1.0)
557
else:
558
raise NotImplementedError
559
alpha = tf.expand_dims(self.alpha, 0)
560
batch_size = tf.shape(real_images)[0]
561
real_labels = tf.ones(batch_size)
562
fake_labels = -tf.ones(batch_size)
563
564
z = tf.random.normal((batch_size, self.z_dim))
565
const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
566
noise = self.generate_noise(batch_size)
567
568
# generator
569
with tf.GradientTape() as g_tape:
570
w = self.mapping(z)
571
fake_images = self.generator([const_input, w, noise, alpha])
572
pred_fake = self.discriminator([fake_images, alpha])
573
g_loss = wasserstein_loss(real_labels, pred_fake)
574
575
trainable_weights = (
576
self.mapping.trainable_weights + self.generator.trainable_weights
577
)
578
gradients = g_tape.gradient(g_loss, trainable_weights)
579
self.g_optimizer.apply_gradients(zip(gradients, trainable_weights))
580
581
# discriminator
582
with tf.GradientTape() as gradient_tape, tf.GradientTape() as total_tape:
583
# forward pass
584
pred_fake = self.discriminator([fake_images, alpha])
585
pred_real = self.discriminator([real_images, alpha])
586
587
epsilon = tf.random.uniform((batch_size, 1, 1, 1))
588
interpolates = epsilon * real_images + (1 - epsilon) * fake_images
589
gradient_tape.watch(interpolates)
590
pred_fake_grad = self.discriminator([interpolates, alpha])
591
592
# calculate losses
593
loss_fake = wasserstein_loss(fake_labels, pred_fake)
594
loss_real = wasserstein_loss(real_labels, pred_real)
595
loss_fake_grad = wasserstein_loss(fake_labels, pred_fake_grad)
596
597
# gradient penalty
598
gradients_fake = gradient_tape.gradient(loss_fake_grad, [interpolates])
599
gradient_penalty = self.loss_weights[
600
"gradient_penalty"
601
] * self.gradient_loss(gradients_fake)
602
603
# drift loss
604
all_pred = tf.concat([pred_fake, pred_real], axis=0)
605
drift_loss = self.loss_weights["drift"] * tf.reduce_mean(all_pred**2)
606
607
d_loss = loss_fake + loss_real + gradient_penalty + drift_loss
608
609
gradients = total_tape.gradient(
610
d_loss, self.discriminator.trainable_weights
611
)
612
self.d_optimizer.apply_gradients(
613
zip(gradients, self.discriminator.trainable_weights)
614
)
615
616
# Update metrics
617
self.d_loss_metric.update_state(d_loss)
618
self.g_loss_metric.update_state(g_loss)
619
return {
620
"d_loss": self.d_loss_metric.result(),
621
"g_loss": self.g_loss_metric.result(),
622
}
623
624
def call(self, inputs: dict()):
625
style_code = inputs.get("style_code", None)
626
z = inputs.get("z", None)
627
noise = inputs.get("noise", None)
628
batch_size = inputs.get("batch_size", 1)
629
alpha = inputs.get("alpha", 1.0)
630
alpha = tf.expand_dims(alpha, 0)
631
if style_code is None:
632
if z is None:
633
z = tf.random.normal((batch_size, self.z_dim))
634
style_code = self.mapping(z)
635
636
if noise is None:
637
noise = self.generate_noise(batch_size)
638
639
# self.alpha.assign(alpha)
640
641
const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
642
images = self.generator([const_input, style_code, noise, alpha])
643
images = np.clip((images * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)
644
645
return images
646
647
648
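"""
For intuition, here is a standalone toy illustration (an assumption of this write-up,
using a stand-in critic rather than the model above) of the gradient-penalty idea in
`train_step`: penalize the deviation of the critic's gradient norm from 1 on
interpolated samples.
"""

x_hat = tf.random.normal((4, 8))  # stand-in for interpolated images
with tf.GradientTape() as tape:
    tape.watch(x_hat)
    score = tf.reduce_sum(x_hat**2, axis=1)  # stand-in critic
grad = tape.gradient(score, x_hat)
gp = tf.reduce_mean((tf.norm(grad, axis=1) - 1.0) ** 2)
print(float(gp))
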
"""
649
## Training
650
651
We first build the StyleGAN at smallest resolution, such as 4x4 or 8x8. Then we
652
progressively grow the model to higher resolution by appending new generator and
653
discriminator blocks.
654
"""
655
656
START_RES = 4
657
TARGET_RES = 128
658
659
style_gan = StyleGAN(start_res=START_RES, target_res=TARGET_RES)
660
661
"""
662
The training for each new resolution happens in two phases - "transition" and "stable".
663
In the transition phase, the features from the previous resolution are mixed with the
664
current resolution. This allows for a smoother transition when scaling up. We use each
665
epoch in `model.fit()` as a phase.
666
"""
667
668
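"""
To make the phase logic concrete, a toy illustration (not part of training) of how
`alpha` ramps linearly from 0 to 1 over a TRANSITION epoch in `train_step`:
"""

steps_per_epoch_demo = 4
for step in range(1, steps_per_epoch_demo + 1):
    print(step, step / steps_per_epoch_demo)  # alpha: 0.25, 0.5, 0.75, 1.0
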
def train(
    start_res=START_RES,
    target_res=TARGET_RES,
    steps_per_epoch=5000,
    display_images=True,
):
    opt_cfg = {"learning_rate": 1e-3, "beta_1": 0.0, "beta_2": 0.99, "epsilon": 1e-8}

    val_batch_size = 16
    val_z = tf.random.normal((val_batch_size, style_gan.z_dim))
    val_noise = style_gan.generate_noise(val_batch_size)

    start_res_log2 = int(np.log2(start_res))
    target_res_log2 = int(np.log2(target_res))

    for res_log2 in range(start_res_log2, target_res_log2 + 1):
        res = 2**res_log2
        for phase in ["TRANSITION", "STABLE"]:
            if res == start_res and phase == "TRANSITION":
                continue

            train_dl = create_dataloader(res)

            steps = int(train_step_ratio[res_log2] * steps_per_epoch)

            style_gan.compile(
                d_optimizer=tf.keras.optimizers.legacy.Adam(**opt_cfg),
                g_optimizer=tf.keras.optimizers.legacy.Adam(**opt_cfg),
                loss_weights={"gradient_penalty": 10, "drift": 0.001},
                steps_per_epoch=steps,
                res=res,
                phase=phase,
                run_eagerly=False,
            )

            prefix = f"res_{res}x{res}_{style_gan.phase}"

            ckpt_cb = keras.callbacks.ModelCheckpoint(
                f"checkpoints/stylegan_{res}x{res}.ckpt",
                save_weights_only=True,
                verbose=0,
            )
            print(phase)
            style_gan.fit(
                train_dl, epochs=1, steps_per_epoch=steps, callbacks=[ckpt_cb]
            )

            if display_images:
                images = style_gan({"z": val_z, "noise": val_noise, "alpha": 1.0})
                plot_images(images, res_log2)
"""
722
StyleGAN can take a long time to train, in the code below, a small `steps_per_epoch`
723
value of 1 is used to sanity-check the code is working alright. In practice, a larger
724
`steps_per_epoch` value (over 10000)
725
is required to get decent results.
726
"""
727
728
train(start_res=4, target_res=16, steps_per_epoch=1, display_images=False)
729
730
"""
731
## Results
732
733
We can now run some inference using pre-trained 64x64 checkpoints. In general, the image
734
fidelity increases with the resolution. You can try to train this StyleGAN to resolutions
735
above 128x128 with the CelebA HQ dataset.
736
"""
737
738
url = "https://github.com/soon-yau/stylegan_keras/releases/download/keras_example_v1.0/stylegan_128x128.ckpt.zip"
739
740
weights_path = keras.utils.get_file(
741
"stylegan_128x128.ckpt.zip",
742
url,
743
extract=True,
744
cache_dir=os.path.abspath("."),
745
cache_subdir="pretrained",
746
)
747
748
style_gan.grow_model(128)
749
style_gan.load_weights(os.path.join("pretrained/stylegan_128x128.ckpt"))
750
751
tf.random.set_seed(196)
752
batch_size = 2
753
z = tf.random.normal((batch_size, style_gan.z_dim))
754
w = style_gan.mapping(z)
755
noise = style_gan.generate_noise(batch_size=batch_size)
756
images = style_gan({"style_code": w, "noise": noise, "alpha": 1.0})
757
plot_images(images, 5)
758
759
"""
760
## Style Mixing
761
762
We can also mix styles from two images to create a new image.
763
"""
764
765
alpha = 0.4
766
w_mix = np.expand_dims(alpha * w[0] + (1 - alpha) * w[1], 0)
767
noise_a = [np.expand_dims(n[0], 0) for n in noise]
768
mix_images = style_gan({"style_code": w_mix, "noise": noise_a})
769
image_row = np.hstack([images[0], images[1], mix_images[0]])
770
plt.figure(figsize=(9, 3))
771
plt.imshow(image_row)
772
plt.axis("off")
773
774