"""
Title: Music Generation with Transformer Models
Author: [Joaquin Jimenez](https://github.com/johacks/)
Date created: 2024/11/22
Last modified: 2024/11/26
Description: Use a Transformer model to train on MIDI data and generate music sequences.
Accelerator: GPU
"""

"""
## Introduction

In this tutorial, we learn how to build a music generation model using a
Transformer decoder-only architecture.
The model is trained on the [Maestro dataset](https://magenta.tensorflow.org/datasets/maestro)
and implemented using Keras 3.
Along the way, we explore MIDI tokenization and relative global attention mechanisms.

This example is based on the paper "Music Transformer" by Huang et al. (2018).
Check out the original [paper](https://arxiv.org/abs/1809.04281) and
[code](https://github.com/jason9693/MusicTransformer-tensorflow2.0).
"""

"""
25
## Setup
26
27
Before we start, let's import and install all the libraries we need.
28
"""
29
30
"""shell
31
pip install -qq midi_neural_processor
32
pip install -qq keras_hub
33
pip install -qq "keras>=3.6.0" # Allows use of keras.utils.Config.
34
"""
35
36
"""
37
### Optional dependencies
38
39
To hear the audio, install the following additional dependencies:
40
"""
41
42
"""shell
43
sudo apt-get -qq install -y fluidsynth 2> /dev/null
44
pip install -qq pyfluidsynth scipy
45
"""
46
import os
import random
import tempfile

import keras
import midi_neural_processor.processor as midi_tokenizer
import numpy as np
from keras import callbacks, layers, ops, optimizers, utils
from keras_hub import layers as hub_layers
from os import path

"""
## Configuration

Let's define the configuration for the model and the dataset used in this example.
"""
event_range = midi_tokenizer.RANGE_NOTE_ON
event_range += midi_tokenizer.RANGE_NOTE_OFF
event_range += midi_tokenizer.RANGE_TIME_SHIFT
event_range += midi_tokenizer.RANGE_VEL
CONFIG = utils.Config(
    max_sequence_len=2048,
    embedding_dim=256,
    num_transformer_blocks=6,
    batch_size=6,
    token_pad=event_range,
    token_start_of_sentence=event_range + 1,
    token_end_of_sentence=event_range + 2,
    vocabulary_size=event_range + 3,
    model_out="tmp/music_transformer.keras",
    seed=42,
)
utils.set_random_seed(CONFIG.seed)

"""
83
## Maestro dataset
84
85
The Maestro dataset contains MIDI files for piano performances.
86
87
### Download the dataset
88
89
We now download and extract the dataset, then move the MIDI files to a new directory.
90
"""
91
92
93
def download_maestro(output_dir=None):
    """Download the Maestro MIDI dataset.
    Extracted from: https://magenta.tensorflow.org/datasets/maestro
    """
    # Ensure the output directory exists
    output_dir = tempfile.mkdtemp() if output_dir is None else output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Download and extract zip file
    dir = utils.get_file(
        origin="https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip",
        extract=True,
    )

    # Gather all MIDI files
    midi_files, file_paths = set(), list()
    for root, _, files in os.walk(dir):
        for file in files:
            if file.lower().endswith(".midi") or file.lower().endswith(".mid"):
                midi_files.add(path.join(root, file))

    # Move the files to the output directory
    for file in sorted(midi_files):
        file_paths.append(new_path := path.join(output_dir, path.basename(file)))
        os.rename(file, new_path)
    return file_paths


paths = list(sorted(download_maestro(output_dir="datasets/maestro")))
output_dir = path.dirname(paths[0])

"""
126
### Split the dataset
127
128
We can now split the dataset into training and validation sets.
129
"""
130
131
indices = np.random.permutation(len(paths))
132
split = int(len(paths) * 0.1)
133
train_paths = [paths[i] for i in indices[split:]]
134
val_paths = [paths[i] for i in indices[:split]]
135
136
"""
137
### Hear a MIDI file
138
139
We use the pretty_midi library and fluidsynth to convert MIDI files into waveform audio.
140
This allows us to listen to the data samples before and after processing.
141
142
The following dependencies are required to play the audio:
143
- fluidsynth: `sudo apt install -y fluidsynth`
144
- pyfluidsynth, scipy: `pip install pyfluidsynth scipy`
145
"""
146
147
148
def visualize_midi(midi_path, sampling_rate=16000, seconds=15, out_dir=None):
    import pretty_midi
    from scipy.io.wavfile import write as write_wav
    from IPython.display import Audio

    # Create the audio waveform
    pretty_midi_file = pretty_midi.PrettyMIDI(midi_path)
    waveform = pretty_midi_file.fluidsynth(fs=sampling_rate)[: seconds * sampling_rate]

    # Display the audio if no path is provided
    if out_dir is None:
        # IPython display
        return Audio(waveform, rate=sampling_rate)

    # Save the audio to a file
    os.makedirs(out_dir, exist_ok=True)
    audio_path = path.join(out_dir, path.basename(midi_path).split(".")[0] + ".wav")
    write_wav(audio_path, sampling_rate, (waveform * 32767).astype(np.int16))
    return audio_path


print(visualize_midi(train_paths[0], out_dir="tmp/"))  # Saved audio path
visualize_midi(train_paths[0])  # Display the audio if in a Jupyter notebook

"""
174
### Tokenize the data
175
176
We now preprocess the MIDI files into a tokenized format for training.
177
"""
178
179
180
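"""
Before building the preprocessing pipeline, here is a quick illustration (not part of the
original pipeline) of what the tokenizer produces: a flat sequence of integer event tokens.
"""

demo_tokens = midi_tokenizer.encode_midi(train_paths[0])
print(f"Number of tokens: {len(demo_tokens)}, first 10 tokens: {demo_tokens[:10]}")

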
def encode_midi_task(midi_path):
    """Define a task that tokenizes a MIDI file."""
    import midi_neural_processor.processor as midi_tokenizer

    return midi_tokenizer.encode_midi(midi_path)


def preprocess_midi_files(file_paths, save_dir=None):
    """Preprocess a list of MIDI files and save the notes to a file."""
    from multiprocessing import Pool, cpu_count

    # Assume all files are in the same directory and save to the same directory
    save_dir = path.dirname(file_paths[0]) if save_dir is None else save_dir
    os.makedirs(save_dir, exist_ok=True)

    # Check if the notes have already been preprocessed
    output_file = path.join(save_dir, "notes.npz")
    if path.exists(output_file):
        npz_file = np.load(output_file)
        return [npz_file[key] for key in npz_file.keys()]

    # Preprocess the MIDI files in parallel
    progbar = utils.Progbar(len(file_paths), unit_name="MIDI_file", interval=5)
    pool = Pool(cpu_count() - 1)
    all_notes = []
    for notes in pool.imap_unordered(encode_midi_task, file_paths):
        progbar.add(1)
        all_notes.append(np.array(notes))

    # Save the notes to a file
    np.savez(output_file, *all_notes)
    return all_notes


train_midis = preprocess_midi_files(train_paths, path.join(output_dir, "train"))
val_midis = preprocess_midi_files(val_paths, path.join(output_dir, "val"))

"""
219
### Dataset objects
220
221
We now define a dataset class that yields batches of input sequences and target sequences.
222
"""
223
224
225
class MidiDataset(utils.PyDataset):
    """A dataset for MIDI files that yields batches of input sequences and target sequences."""

    def __init__(
        self,
        encoded_midis,
        batch_size=CONFIG.batch_size,
        max_sequence_len=CONFIG.max_sequence_len,
    ):
        super(MidiDataset, self).__init__()
        self.batch_size = batch_size
        self.max_sequence_len = max_sequence_len
        self.encoded_midis = encoded_midis
        batches, last_batch_size = divmod(len(encoded_midis), batch_size)
        self._num_batches = batches + int(last_batch_size > 0)

    def __len__(self):
        """Get the number of batches."""
        return self._num_batches

    def __getitem__(self, idx):
        """Generate random inputs and corresponding targets for the model."""
        # Same as in the original paper, we always get a random batch.
        # See: https://github.com/jason9693/MusicTransformer-tensorflow2.0/blob/f7c06c0cb2e9cdddcbf6db779cb39cd650282778/data.py
        batch = random.sample(self.encoded_midis, k=self.batch_size)

        # Convert the batch to sequences
        batch_data = [
            self._get_sequence(midi, self.max_sequence_len + 1) for midi in batch
        ]
        batch_data = np.array(batch_data)

        # Split the data into input and target sequences
        return batch_data[:, :-1], batch_data[:, 1:]

    def _get_sequence(self, data, max_length):
        """Get a random sequence of notes from a file."""
        # Truncate or pad the sequence
        if len(data) > max_length:
            start = random.randrange(0, len(data) - max_length)
            data = data[start : start + max_length]
        elif len(data) < max_length:
            data = np.append(data, CONFIG.token_end_of_sentence)

        # Pad the sequence if necessary
        if len(data) < max_length:
            data = np.concatenate(
                (data, np.full(max_length - len(data), CONFIG.token_pad))
            )
        return np.asanyarray(data, dtype="int32")


train_dataset, val_dataset = MidiDataset(train_midis), MidiDataset(val_midis)

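"""
As a quick sanity check (illustrative, not part of the original example), we can draw one
batch and confirm that the targets are the inputs shifted left by one token.
"""

demo_inputs, demo_targets = train_dataset[0]
print(demo_inputs.shape, demo_targets.shape)  # (CONFIG.batch_size, CONFIG.max_sequence_len) each
print(bool((demo_inputs[:, 1:] == demo_targets[:, :-1]).all()))  # True: shifted by one token
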
"""
281
## Model definition
282
283
It is time to define the model architecture. We use a Transformer decoder
284
architecture with a custom attention mechanism, relative global attention.
285
286
### Relative Global Attention
287
288
The following code implements the Relative Global Attention layer. It is used
289
in place of the standard multi-head attention layer in the Transformer decoder.
290
The main difference is that it includes a relative positional encoding that
291
allows the model to learn relative positional information between tokens.
292
"""
293
294
295
@keras.utils.register_keras_serializable()
class RelativeGlobalAttention(layers.Layer):
    """
    From Music Transformer (Huang et al., 2018)
    https://arxiv.org/abs/1809.04281
    """

    def __init__(self, num_heads, embedding_dim, max_sequence_len, **kwargs):
        super().__init__(**kwargs)
        self.key_length = None
        self.max_sequence_len = max_sequence_len
        self.relative_embedding = None
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.head_dim = embedding_dim // num_heads
        self.query_dense = layers.Dense(int(self.embedding_dim))
        self.key_dense = layers.Dense(int(self.embedding_dim))
        self.value_dense = layers.Dense(int(self.embedding_dim))
        self.output_dense = layers.Dense(embedding_dim, name="output")

    def build(self, input_shape):
        self.query_length = input_shape[0][1]
        self.key_length = input_shape[1][1]
        self.relative_embedding = self.add_weight(
            (self.max_sequence_len, int(self.head_dim)), name="relative_embedding"
        )

    def _apply_dense_layer_and_split_heads(self, inputs, dense_layer):
        # Apply linear transformation
        inputs = dense_layer(inputs)
        new_shape = ops.shape(inputs)
        # Reshape to split by attention heads
        reshaped = ops.reshape(inputs, (new_shape[0], new_shape[1], self.num_heads, -1))
        # Transpose for head-first format
        return ops.transpose(reshaped, (0, 2, 1, 3))

    def call(self, inputs, mask=None):
        # Compute Q, K, V: Batch, head, sequence, features
        query = self._apply_dense_layer_and_split_heads(inputs[0], self.query_dense)
        key = self._apply_dense_layer_and_split_heads(inputs[1], self.key_dense)
        value = self._apply_dense_layer_and_split_heads(inputs[2], self.value_dense)

        # Compute scaled dot-product attention scores
        attention_scores = ops.matmul(query, ops.transpose(key, [0, 1, 3, 2]))

        # Compute relative positional encoding and combine with attention scores
        start_idx = max(0, self.max_sequence_len - ops.shape(query)[2])
        relative_embedding = self.relative_embedding[start_idx:, :]
        attention_scores += self._compute_attention_scores(query, relative_embedding)
        logits = attention_scores / ops.sqrt(self.head_dim)

        # Apply mask if provided
        if mask is not None:
            logits += ops.cast(mask, "float32") * -1e9

        # Compute attention weights
        attention_weights = ops.nn.softmax(logits, axis=-1)
        attention_output = ops.matmul(attention_weights, value)

        # Merge heads and apply final linear transformation
        merged_attention = ops.transpose(attention_output, (0, 2, 1, 3))
        merged_attention = ops.reshape(
            merged_attention, (ops.shape(merged_attention)[0], -1, self.embedding_dim)
        )
        output = self.output_dense(merged_attention)

        return output, attention_weights

    def _compute_attention_scores(self, query, relative_embedding):
        """
        Compute relative attention scores using positional encodings.
        """
        relative_scores = ops.einsum("bhld, md->bhlm", query, relative_embedding)
        relative_scores = self._apply_mask_to_relative_scores(relative_scores)
        return self._skew_attention_scores(relative_scores)

    def _apply_mask_to_relative_scores(self, scores):
        """
        Apply masking to relative positional scores to ignore future positions.
        """
        mask = ops.flip(
            ops.tri(scores.shape[-2], scores.shape[-1], dtype="float32"), axis=1
        )
        return mask * scores

    def _skew_attention_scores(self, scores):
        """
        Perform skewing operation to align relative attention scores with the sequence.
        """
        padded_scores = ops.pad(scores, ((0, 0), (0, 0), (0, 0), (1, 0)))
        padded_shape = ops.shape(padded_scores)
        reshaped_scores = ops.reshape(
            padded_scores, (-1, padded_shape[1], padded_shape[-1], padded_shape[-2])
        )
        skewed_scores = reshaped_scores[:, :, 1:, :]

        if self.key_length > self.query_length:
            size_diff = self.key_length - self.query_length
            return ops.pad(skewed_scores, [[0, 0], [0, 0], [0, 0], [0, size_diff]])
        else:
            return skewed_scores[:, :, :, : self.key_length]

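"""
To build intuition for the skewing trick above, here is a minimal, illustrative sketch (not
part of the original example) that applies the same pad-reshape-slice steps to a tiny tensor
of relative scores. The shapes are hypothetical; the point is to see how each query row ends
up aligned with its own relative positions.
"""

# Hypothetical relative scores: batch=1, head=1, query length=3, relative positions=3
demo_scores = ops.reshape(ops.arange(9, dtype="float32"), (1, 1, 3, 3))
# Pad one column on the left, reinterpret the buffer with the last two dims exchanged,
# then drop the first row: this is exactly what _skew_attention_scores does.
demo_padded = ops.pad(demo_scores, ((0, 0), (0, 0), (0, 0), (1, 0)))
demo_shape = ops.shape(demo_padded)
demo_skewed = ops.reshape(
    demo_padded, (-1, demo_shape[1], demo_shape[-1], demo_shape[-2])
)[:, :, 1:, :]
print(ops.convert_to_numpy(demo_scores[0, 0]))  # before skewing
print(ops.convert_to_numpy(demo_skewed[0, 0]))  # after skewing
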
"""
399
### Decoder Layer
400
401
Using the RelativeGlobalAttention layer, we can define the DecoderLayer. It is mostly like
402
the standard Transformer decoder layer but with the custom attention mechanism.
403
"""
404
405
406
@keras.utils.register_keras_serializable()
class DecoderLayer(layers.Layer):
    def __init__(self, embedding_dim, num_heads, max_sequence_len, dropout=0.1):
        super(DecoderLayer, self).__init__()

        # Initialize attributes
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.max_sequence_len = max_sequence_len

        # Initialize layers
        self.relative_global_attention_1 = RelativeGlobalAttention(
            num_heads, embedding_dim, max_sequence_len
        )

        self.feed_forward_network_pre = layers.Dense(self.embedding_dim // 2, "relu")
        self.feed_forward_network_pos = layers.Dense(self.embedding_dim)

        self.layer_normalization_1 = layers.LayerNormalization(epsilon=1e-6)
        self.layer_normalization_2 = layers.LayerNormalization(epsilon=1e-6)

        self.dropout_1 = layers.Dropout(dropout)
        self.dropout_2 = layers.Dropout(dropout)

    def call(self, inputs, mask=None, training=False):
        # Attention block. Inputs are (query, key, value)
        attention_out, attention_weights = self.relative_global_attention_1(
            (inputs, inputs, inputs), mask=mask
        )
        attention_out = self.dropout_1(attention_out, training=training)
        attention_out_normalized = self.layer_normalization_1(attention_out + inputs)

        ffn_out = self.feed_forward_network_pre(attention_out)
        ffn_out = self.feed_forward_network_pos(ffn_out)
        ffn_out = self.dropout_2(ffn_out, training=training)
        out = self.layer_normalization_2(attention_out_normalized + ffn_out)

        return out, attention_weights

"""
447
### Decoder
448
449
The Decoder layer is composed of multiple DecoderLayer blocks. It also includes
450
an embedding layer that converts our tokenized input into an embedding representation.
451
"""
452
453
454
@keras.utils.register_keras_serializable()
class Decoder(layers.Layer):
    def __init__(
        self, embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout
    ):
        super(Decoder, self).__init__()

        self.embedding_dim = embedding_dim
        self.num_blocks = num_blocks

        self.embedding = layers.Embedding(vocabulary_size, self.embedding_dim)
        self.positional_encoding = hub_layers.SinePositionEncoding()

        self.decode_layers = [
            DecoderLayer(
                embedding_dim, embedding_dim // 64, max_sequence_len, dropout=dropout
            )
            for _ in range(num_blocks)
        ]
        self.dropout = layers.Dropout(dropout)

    def call(self, inputs, mask=None, training=False, return_attention_weights=False):
        weights = []

        # Adding embedding and position encoding.
        x = self.embedding(inputs)
        x = x * ops.sqrt(ops.cast(self.embedding_dim, "float32"))
        x = x + self.positional_encoding(x)
        x = self.dropout(x, training=training)

        # Passing through the transformer blocks.
        for i in range(self.num_blocks):
            x, w = self.decode_layers[i](x, mask=mask, training=training)
            weights.append(w)
        if return_attention_weights:
            return x, weights
        return x

"""
494
### Music Transformer Decoder
495
496
With the above layers defined, we can now define the MusicTransformerDecoder model. It applies
497
a linear transformation to the output of the decoder to get the logits for each token.
498
"""
499
500
501
@keras.utils.register_keras_serializable()
class MusicTransformerDecoder(keras.Model):
    def __init__(
        self,
        embedding_dim=CONFIG.embedding_dim,
        vocabulary_size=CONFIG.vocabulary_size,
        num_blocks=CONFIG.num_transformer_blocks,
        max_sequence_len=CONFIG.max_sequence_len,
        dropout=0.2,
    ):
        # Initialize attributes
        super(MusicTransformerDecoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.vocabulary_size = vocabulary_size
        self.num_blocks = num_blocks
        self.max_sequence_len = max_sequence_len

        # Initialize layers
        # Transformer decoder
        self.decoder = Decoder(
            embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout
        )
        # Output layer
        self.fc = layers.Dense(self.vocabulary_size, activation=None, name="output")

    @staticmethod
    def get_look_ahead_mask(max_sequence_len, inputs):
        sequence_length = min(max_sequence_len, inputs.shape[1])
        sequence_mask = ops.logical_not(
            ops.tri(sequence_length, sequence_length, dtype="bool")
        )

        inputs = ops.cast(inputs[:, None, None, :], "int32")
        output_pad_tensor = ops.ones_like(inputs) * CONFIG.token_pad
        decoder_output_mask = ops.equal(inputs, output_pad_tensor)
        return ops.cast(ops.logical_or(decoder_output_mask, sequence_mask), "int32")

    def call(self, inputs, training=False):
        mask = self.get_look_ahead_mask(self.max_sequence_len, inputs)
        decoding = self.decoder(
            inputs, mask=mask, training=training, return_attention_weights=False
        )
        return self.fc(decoding)

    # --- Sequence generation methods

    def generate(self, inputs: list, length=CONFIG.max_sequence_len, top_k=5):
        inputs = ops.convert_to_tensor([inputs])

        # Generate a new token using output distribution at given index
        def generate_token(inputs, end_idx):
            distribution = ops.stop_gradient(self.call(inputs)[0, end_idx])

            # Select the top-k tokens and their probabilities
            top_k_distribution, top_k_indices = ops.top_k(distribution, k=top_k)

            # Sample from the top-k probabilities
            new_token_idx = keras.random.categorical(top_k_distribution[None, :], 1)
            return ops.take(top_k_indices, new_token_idx[0])

        # Compute the number of tokens to add
        added_tokens = min(length, self.max_sequence_len - inputs.shape[1])
        progbar = utils.Progbar(added_tokens, unit_name="token", interval=5)

        # Pad the input sequence that will be filled with generated tokens
        out = ops.pad(inputs, ((0, 0), (0, added_tokens)), "constant", CONFIG.token_pad)

        # Generate tokens using top-k sampling
        for token_idx in range(inputs.shape[1] - 1, inputs.shape[1] - 1 + added_tokens):
            token = ops.cast(generate_token(out, end_idx=token_idx), out.dtype)
            out = ops.scatter_update(out, ((0, token_idx + 1),), token)
            progbar.add(1)

        return ops.convert_to_numpy(out[0])

    # --- Serialization methods

    def get_config(self):
        atts = ["embedding_dim", "vocabulary_size", "num_blocks", "max_sequence_len"]
        return {a: getattr(self, a) for a in atts}

    @classmethod
    def from_config(cls, config):
        return cls(**config)

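"""
To see what the combined causal and padding mask looks like (a small illustration, not part
of the original example), we can apply `get_look_ahead_mask` to a toy batch whose last
position is a pad token: positions above the diagonal and the padded column are both masked.
"""

demo_inputs = ops.convert_to_tensor([[5, 7, CONFIG.token_pad]])
demo_mask = MusicTransformerDecoder.get_look_ahead_mask(CONFIG.max_sequence_len, demo_inputs)
print(ops.convert_to_numpy(demo_mask)[0, 0])  # 1 marks masked positions
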
"""
588
### Loss function
589
590
We define a custom loss function that computes the categorical cross-entropy
591
loss for the model. It is computed only for non-padding tokens and uses
592
`from_logits=True` since the model outputs logits.
593
"""
594
595
596
@keras.utils.register_keras_serializable()
def train_loss(y_true, y_pred):
    mask = ops.cast(ops.logical_not(ops.equal(y_true, CONFIG.token_pad)), "float32")
    y_true = ops.one_hot(ops.cast(y_true, "int32"), CONFIG.vocabulary_size)
    return ops.categorical_crossentropy(y_true, y_pred, from_logits=True) * mask

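"""
A quick check of the masking behaviour (illustrative, not part of the original example): any
position whose target is the pad token contributes exactly zero to the loss.
"""

demo_loss_targets = ops.convert_to_tensor([[CONFIG.token_pad, 42]])
demo_loss_logits = ops.zeros((1, 2, CONFIG.vocabulary_size))
print(ops.convert_to_numpy(train_loss(demo_loss_targets, demo_loss_logits)))  # first entry is 0.0
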
"""
604
### Learning rate schedule
605
606
Following the Music Transformer paper, we define an adapted exponential decay
607
learning rate schedule that takes into account the embedding dimension.
608
"""
609
610
611
@keras.utils.register_keras_serializable()
class CustomSchedule(optimizers.schedules.LearningRateSchedule):
    def __init__(self, embedding_dim, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        self.embedding_dim = embedding_dim
        self.warmup_steps = warmup_steps

        self._embedding_dim = ops.cast(self.embedding_dim, "float32")
        # Numerical stability adjustment on torch, which is less precise
        self._lr_adjust = 0.1 if keras.backend.backend() == "torch" else 1.0

    def get_config(self):
        return {"embedding_dim": self.embedding_dim, "warmup_steps": self.warmup_steps}

    def __call__(self, step):
        step_rsqrt = ops.rsqrt(ops.cast(step, "float32"))
        warmup_adjust = step * (self.warmup_steps**-1.5)
        output = ops.rsqrt(self._embedding_dim) * ops.minimum(step_rsqrt, warmup_adjust)
        return self._lr_adjust * output

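"""
As a quick sanity check (illustrative, not part of the original example), we can evaluate the
schedule at a few steps and observe the warmup followed by the inverse-square-root decay.
"""

demo_schedule = CustomSchedule(CONFIG.embedding_dim)
for demo_step in [1, 1000, 4000, 16000]:
    print(demo_step, float(ops.convert_to_numpy(demo_schedule(demo_step))))
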
"""
634
## Training the model
635
636
We can now train the model on the Maestro dataset. First, we define a training
637
function. This function compiles the model, trains it, and saves the best model
638
checkpoint. This way, we can continue training from the best model checkpoint
639
if needed.
640
"""
641
642
643
def train_model(model, train_ds, val_ds, epochs=15):
    # Configure optimizer
    learning_rate = CustomSchedule(CONFIG.embedding_dim)
    optimizer = optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

    # Compile the model
    model.compile(optimizer=optimizer, loss=train_loss)

    # Train the model
    save_cb = callbacks.ModelCheckpoint(CONFIG.model_out, save_best_only=True)
    model.fit(
        train_ds, validation_data=val_ds, epochs=epochs, callbacks=[save_cb], verbose=2
    )
    return model


"""
If a model checkpoint exists, we load it and can optionally continue training from it;
otherwise, we train a new model from scratch.
"""
if path.exists(CONFIG.model_out):
    model = keras.models.load_model(CONFIG.model_out)
    # Uncomment to continue training from the checkpoint
    # train_model(model, train_dataset, val_dataset, epochs=10)
else:
    # Train the model
    model = train_model(MusicTransformerDecoder(), train_dataset, val_dataset)

"""
673
## Generate music
674
675
We can now generate music using the trained model. We use an existing MIDI file
676
as a seed and generate a new sequence.
677
"""
678
679
680
def generate_music(model, seed_path, length=1024, out_dir=None, top_k=None):
    # Ensure the output directory exists
    out_dir = out_dir if out_dir is not None else tempfile.mkdtemp()
    os.makedirs(out_dir, exist_ok=True)

    # Get some tokens from the MIDI file
    inputs = midi_tokenizer.encode_midi(seed_path)[100:125]
    print(f"Seed tokens: {inputs}")

    # Generate music that follows the input tokens until the maximum length
    result = model.generate(inputs, length=length, top_k=top_k)

    output_path = path.join(out_dir, path.basename(seed_path).split(".")[0] + ".mid")
    midi_tokenizer.decode_midi(result, output_path)
    return output_path


output_file = generate_music(model, val_paths[-1], out_dir="tmp/", top_k=15)
print(visualize_midi(output_file, out_dir="tmp/"))  # Saved audio path
visualize_midi(output_file)  # Display the audio if in a Jupyter notebook

"""
702
## Conclusion
703
704
In this example, we learned how to build a music generation model using a custom
705
Transformer decoder architecture.
706
707
We did it following the Music Transformer paper by Huang et al. (2018).
708
To do so we had to:
709
710
- Define a custom loss function and learning rate schedule.
711
- Define a custom attention mechanism.
712
- Preprocess MIDI files into a tokenized format.
713
714
After training the model on the Maestro dataset, we generated music sequences
715
using a seed MIDI file.
716
717
### Next steps
718
719
We could further improve inference times by caching attention weights during the
720
forward pass, in a similar way as `keras_hub` `CausalLM` models, which use the
721
`CachedMultiHeadAttention` layer.
722
"""
723
724