Path: blob/master/examples/generative/midi_generation_with_transformer.py
"""1Title: Music Generation with Transformer Models2Author: [Joaquin Jimenez](https://github.com/johacks/)3Date created: 2024/11/224Last modified: 2024/11/265Description: Use a Transformer model to train on MIDI data and generate music sequences.6Accelerator: GPU7"""89"""10## Introduction1112In this tutorial, we learn how to build a music generation model using a13Transformer decode-only architecture.14The model is trained on the [Maestro dataset](https://magenta.tensorflow.org/datasets/maestro)15and implemented using keras 3.16In the process, we explore MIDI tokenization, and relative global attention mechanisms.1718This example is based on the paper "Music Transformer" by Huang et al. (2018).19Check out the original [paper](https://arxiv.org/abs/1809.04281) and20[code](https://github.com/jason9693/MusicTransformer-tensorflow2.0).21"""2223"""24## Setup2526Before we start, let's import and install all the libraries we need.27"""2829"""shell30pip install -qq midi_neural_processor31pip install -qq keras_hub32pip install -qq "keras>=3.6.0" # Allows use of keras.utils.Config.33"""3435"""36### Optional dependencies3738To hear the audio, install the following additional dependencies:39"""4041"""shell42sudo apt-get -qq install -y fluidsynth 2> /dev/null43pip install -qq pyfluidsynth scipy44"""4546import os47import random48import tempfile4950import keras51import midi_neural_processor.processor as midi_tokenizer52import numpy as np53from keras import callbacks, layers, ops, optimizers, utils54from keras_hub import layers as hub_layers55from os import path5657"""58## Configuration5960Lets define the configuration for the model and the dataset to be used in this example.61"""62event_range = midi_tokenizer.RANGE_NOTE_ON63event_range += midi_tokenizer.RANGE_NOTE_OFF64event_range += midi_tokenizer.RANGE_TIME_SHIFT65event_range += midi_tokenizer.RANGE_VEL66CONFIG = utils.Config(67max_sequence_len=2048,68embedding_dim=256,69num_transformer_blocks=6,70batch_size=6,71token_pad=event_range,72token_start_of_sentence=event_range + 1,73token_end_of_sentence=event_range + 2,74vocabulary_size=event_range + 3,75model_out="tmp/music_transformer.keras",76seed=42,77)78utils.set_random_seed(CONFIG.seed)798081"""82## Maestro dataset8384The Maestro dataset contains MIDI files for piano performances.8586### Download the dataset8788We now download and extract the dataset, then move the MIDI files to a new directory.89"""909192def download_maestro(output_dir=None):93"""Download the Maestro MIDI dataset.94Extracted from: https://magenta.tensorflow.org/datasets/maestro95"""96# Ensure the output directory exists97output_dir = tempfile.mkdtemp() if output_dir is None else output_dir98os.makedirs(output_dir, exist_ok=True)99100# Download and extract zip file101dir = utils.get_file(102origin="https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip",103extract=True,104)105106# Gather all MIDI files107midi_files, file_paths = set(), list()108for root, _, files in os.walk(dir):109for file in files:110if file.lower().endswith(".midi") or file.lower().endswith(".mid"):111midi_files.add(path.join(root, file))112113# Move the files to the output directory114for file in sorted(midi_files):115file_paths.append(new_path := path.join(output_dir, path.basename(file)))116os.rename(file, new_path)117return file_paths118119120paths = list(sorted(download_maestro(output_dir="datasets/maestro")))121output_dir = path.dirname(paths[0])122123124"""125### Split the dataset126127We can now split the dataset into 
"""
## Maestro dataset

The Maestro dataset contains MIDI files for piano performances.

### Download the dataset

We now download and extract the dataset, then move the MIDI files to a new directory.
"""


def download_maestro(output_dir=None):
    """Download the Maestro MIDI dataset.
    Extracted from: https://magenta.tensorflow.org/datasets/maestro
    """
    # Ensure the output directory exists
    output_dir = tempfile.mkdtemp() if output_dir is None else output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Download and extract zip file
    dir = utils.get_file(
        origin="https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip",
        extract=True,
    )

    # Gather all MIDI files
    midi_files, file_paths = set(), list()
    for root, _, files in os.walk(dir):
        for file in files:
            if file.lower().endswith(".midi") or file.lower().endswith(".mid"):
                midi_files.add(path.join(root, file))

    # Move the files to the output directory
    for file in sorted(midi_files):
        file_paths.append(new_path := path.join(output_dir, path.basename(file)))
        os.rename(file, new_path)
    return file_paths


paths = list(sorted(download_maestro(output_dir="datasets/maestro")))
output_dir = path.dirname(paths[0])


"""
### Split the dataset

We can now split the dataset into training and validation sets.
"""

indices = np.random.permutation(len(paths))
split = int(len(paths) * 0.1)
train_paths = [paths[i] for i in indices[split:]]
val_paths = [paths[i] for i in indices[:split]]

"""
### Hear a MIDI file

We use the pretty_midi library and fluidsynth to convert MIDI files into waveform audio.
This allows us to listen to the data samples before and after processing.

The following dependencies are required to play the audio:
- fluidsynth: `sudo apt install -y fluidsynth`
- pyfluidsynth, scipy: `pip install pyfluidsynth scipy`
"""


def visualize_midi(midi_path, sampling_rate=16000, seconds=15, out_dir=None):
    import pretty_midi
    from scipy.io.wavfile import write as write_wav
    from IPython.display import Audio

    # Create the audio waveform
    pretty_midi_file = pretty_midi.PrettyMIDI(midi_path)
    waveform = pretty_midi_file.fluidsynth(fs=sampling_rate)[: seconds * sampling_rate]

    # Display the audio if no path is provided
    if out_dir is None:
        # IPython display
        return Audio(waveform, rate=sampling_rate)

    # Save the audio to a file
    os.makedirs(out_dir, exist_ok=True)
    audio_path = path.join(out_dir, path.basename(midi_path).split(".")[0] + ".wav")
    write_wav(audio_path, sampling_rate, (waveform * 32767).astype(np.int16))
    return audio_path


print(visualize_midi(train_paths[0], out_dir="tmp/"))  # Saved audio path
visualize_midi(train_paths[0])  # Display the audio if in a Jupyter notebook


"""
### Tokenize the data

We now preprocess the MIDI files into a tokenized format for training.
"""


def encode_midi_task(midi_path):
    """Define a task that tokenizes a MIDI file."""
    import midi_neural_processor.processor as midi_tokenizer

    return midi_tokenizer.encode_midi(midi_path)


def preprocess_midi_files(file_paths, save_dir=None):
    """Preprocess a list of MIDI files and save the notes to a file."""
    from multiprocessing import Pool, cpu_count

    # Assume all files are in the same directory and save to the same directory
    save_dir = path.dirname(file_paths[0]) if save_dir is None else save_dir
    os.makedirs(save_dir, exist_ok=True)

    # Check if the notes have already been preprocessed
    output_file = path.join(save_dir, "notes.npz")
    if path.exists(output_file):
        npz_file = np.load(output_file)
        return [npz_file[key] for key in npz_file.keys()]

    # Preprocess the MIDI files in parallel
    progbar = utils.Progbar(len(file_paths), unit_name="MIDI_file", interval=5)
    pool = Pool(cpu_count() - 1)
    all_notes = []
    for notes in pool.imap_unordered(encode_midi_task, file_paths):
        progbar.add(1)
        all_notes.append(np.array(notes))

    # Save the notes to a file
    np.savez(output_file, *all_notes)
    return all_notes


train_midis = preprocess_midi_files(train_paths, path.join(output_dir, "train"))
val_midis = preprocess_midi_files(val_paths, path.join(output_dir, "val"))
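
"""
Each tokenized MIDI file is now a 1-D array of integer event ids. Purely as an
optional inspection step, we can look at how many files we have and what the
beginning of one tokenized performance looks like:
"""
print("Tokenized training files:", len(train_midis))
print("Tokenized validation files:", len(val_midis))
print(
    "First training file: length =", len(train_midis[0]),
    "| first tokens:", train_midis[0][:10],
)
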
"""
### Dataset objects

We now define a dataset class that yields batches of input sequences and target sequences.
"""


class MidiDataset(utils.PyDataset):
    """A dataset for MIDI files that yields batches of input sequences and target sequences."""

    def __init__(
        self,
        encoded_midis,
        batch_size=CONFIG.batch_size,
        max_sequence_len=CONFIG.max_sequence_len,
    ):
        super(MidiDataset, self).__init__()
        self.batch_size = batch_size
        self.max_sequence_len = max_sequence_len
        self.encoded_midis = encoded_midis
        batches, last_batch_size = divmod(len(encoded_midis), batch_size)
        self._num_batches = batches + int(last_batch_size > 0)

    def __len__(self):
        """Get the number of batches."""
        return self._num_batches

    def __getitem__(self, idx):
        """Generate random inputs and corresponding targets for the model."""
        # Same as in the original paper, we always get a random batch.
        # See: https://github.com/jason9693/MusicTransformer-tensorflow2.0/blob/f7c06c0cb2e9cdddcbf6db779cb39cd650282778/data.py
        batch = random.sample(self.encoded_midis, k=self.batch_size)

        # Convert the batch to sequences
        batch_data = [
            self._get_sequence(midi, self.max_sequence_len + 1) for midi in batch
        ]
        batch_data = np.array(batch_data)

        # Split the data into input and target sequences
        return batch_data[:, :-1], batch_data[:, 1:]

    def _get_sequence(self, data, max_length):
        """Get a random sequence of notes from a file."""
        # Truncate or pad the sequence
        if len(data) > max_length:
            start = random.randrange(0, len(data) - max_length)
            data = data[start : start + max_length]
        elif len(data) < max_length:
            data = np.append(data, CONFIG.token_end_of_sentence)

        # Pad the sequence if necessary
        if len(data) < max_length:
            data = np.concatenate(
                (data, np.full(max_length - len(data), CONFIG.token_pad))
            )
        return np.asanyarray(data, dtype="int32")


train_dataset, val_dataset = MidiDataset(train_midis), MidiDataset(val_midis)
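
"""
Because the targets are just the inputs shifted one token to the left, an optional
sanity check on a single batch is easy to do: inputs and targets share the same
shape, and the target sequence starts one step after the input sequence.
"""
example_inputs, example_targets = train_dataset[0]
print("Inputs shape:", example_inputs.shape)  # (batch_size, max_sequence_len)
print("Targets shape:", example_targets.shape)
shifted_ok = bool(np.all(example_inputs[0, 1:] == example_targets[0, :-1]))
print("Target is input shifted by one:", shifted_ok)
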
"""
## Model definition

It is time to define the model architecture. We use a Transformer decoder
architecture with a custom attention mechanism, relative global attention.

### Relative Global Attention

The following code implements the Relative Global Attention layer. It is used
in place of the standard multi-head attention layer in the Transformer decoder.
The main difference is that it includes a relative positional encoding that
allows the model to learn relative positional information between tokens.
"""


@keras.utils.register_keras_serializable()
class RelativeGlobalAttention(layers.Layer):
    """
    From Music Transformer (Huang et al., 2018)
    https://arxiv.org/abs/1809.04281
    """

    def __init__(self, num_heads, embedding_dim, max_sequence_len, **kwargs):
        super().__init__(**kwargs)
        self.key_length = None
        self.max_sequence_len = max_sequence_len
        self.relative_embedding = None
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.head_dim = embedding_dim // num_heads
        self.query_dense = layers.Dense(int(self.embedding_dim))
        self.key_dense = layers.Dense(int(self.embedding_dim))
        self.value_dense = layers.Dense(int(self.embedding_dim))
        self.output_dense = layers.Dense(embedding_dim, name="output")

    def build(self, input_shape):
        self.query_length = input_shape[0][1]
        self.key_length = input_shape[1][1]
        self.relative_embedding = self.add_weight(
            (self.max_sequence_len, int(self.head_dim)), name="relative_embedding"
        )

    def _apply_dense_layer_and_split_heads(self, inputs, dense_layer):
        # Apply linear transformation
        inputs = dense_layer(inputs)
        new_shape = ops.shape(inputs)
        # Reshape to split by attention heads
        reshaped = ops.reshape(inputs, (new_shape[0], new_shape[1], self.num_heads, -1))
        # Transpose for head-first format
        return ops.transpose(reshaped, (0, 2, 1, 3))

    def call(self, inputs, mask=None):
        # Compute Q, K, V: Batch, head, sequence, features
        query = self._apply_dense_layer_and_split_heads(inputs[0], self.query_dense)
        key = self._apply_dense_layer_and_split_heads(inputs[1], self.key_dense)
        value = self._apply_dense_layer_and_split_heads(inputs[2], self.value_dense)

        # Compute scaled dot-product attention scores
        attention_scores = ops.matmul(query, ops.transpose(key, [0, 1, 3, 2]))

        # Compute relative positional encoding and combine with attention scores
        start_idx = max(0, self.max_sequence_len - ops.shape(query)[2])
        relative_embedding = self.relative_embedding[start_idx:, :]
        attention_scores += self._compute_attention_scores(query, relative_embedding)
        logits = attention_scores / ops.sqrt(self.head_dim)

        # Apply mask if provided
        if mask is not None:
            logits += ops.cast(mask, "float32") * -1e9

        # Compute attention weights
        attention_weights = ops.nn.softmax(logits, axis=-1)
        attention_output = ops.matmul(attention_weights, value)

        # Merge heads and apply final linear transformation
        merged_attention = ops.transpose(attention_output, (0, 2, 1, 3))
        merged_attention = ops.reshape(
            merged_attention, (ops.shape(merged_attention)[0], -1, self.embedding_dim)
        )
        output = self.output_dense(merged_attention)

        return output, attention_weights

    def _compute_attention_scores(self, query, relative_embedding):
        """
        Compute relative attention scores using positional encodings.
        """
        relative_scores = ops.einsum("bhld, md->bhlm", query, relative_embedding)
        relative_scores = self._apply_mask_to_relative_scores(relative_scores)
        return self._skew_attention_scores(relative_scores)

    def _apply_mask_to_relative_scores(self, scores):
        """
        Apply masking to relative positional scores to ignore future positions.
        """
        mask = ops.flip(
            ops.tri(scores.shape[-2], scores.shape[-1], dtype="float32"), axis=1
        )
        return mask * scores

    def _skew_attention_scores(self, scores):
        """
        Perform skewing operation to align relative attention scores with the sequence.
        """
        padded_scores = ops.pad(scores, ((0, 0), (0, 0), (0, 0), (1, 0)))
        padded_shape = ops.shape(padded_scores)
        reshaped_scores = ops.reshape(
            padded_scores, (-1, padded_shape[1], padded_shape[-1], padded_shape[-2])
        )
        skewed_scores = reshaped_scores[:, :, 1:, :]

        if self.key_length > self.query_length:
            size_diff = self.key_length - self.query_length
            return ops.pad(skewed_scores, [[0, 0], [0, 0], [0, 0], [0, size_diff]])
        else:
            return skewed_scores[:, :, :, : self.key_length]
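
"""
Before wiring the layer into a decoder block, we can optionally run a quick shape
check on random data. The configuration used here (4 heads, 16 timesteps) is
arbitrary and only meant to illustrate the expected input and output shapes.
"""
dummy_embeddings = keras.random.normal((1, 16, CONFIG.embedding_dim))
dummy_attention = RelativeGlobalAttention(
    num_heads=4,
    embedding_dim=CONFIG.embedding_dim,
    max_sequence_len=CONFIG.max_sequence_len,
)
# The layer expects a (query, key, value) tuple and returns the attended output
# plus the per-head attention weights.
dummy_output, dummy_weights = dummy_attention(
    (dummy_embeddings, dummy_embeddings, dummy_embeddings)
)
print("Output shape:", dummy_output.shape)  # (batch, sequence, embedding_dim)
print("Attention weights shape:", dummy_weights.shape)  # (batch, heads, seq, seq)
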
"""
### Decoder Layer

Using the RelativeGlobalAttention layer, we can define the DecoderLayer. It largely
follows the standard Transformer decoder layer, but swaps in the custom attention
mechanism defined above.
"""


@keras.utils.register_keras_serializable()
class DecoderLayer(layers.Layer):
    def __init__(self, embedding_dim, num_heads, max_sequence_len, dropout=0.1):
        super(DecoderLayer, self).__init__()

        # Initialize attributes
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.max_sequence_len = max_sequence_len

        # Initialize layers
        self.relative_global_attention_1 = RelativeGlobalAttention(
            num_heads, embedding_dim, max_sequence_len
        )

        self.feed_forward_network_pre = layers.Dense(self.embedding_dim // 2, "relu")
        self.feed_forward_network_pos = layers.Dense(self.embedding_dim)

        self.layer_normalization_1 = layers.LayerNormalization(epsilon=1e-6)
        self.layer_normalization_2 = layers.LayerNormalization(epsilon=1e-6)

        self.dropout_1 = layers.Dropout(dropout)
        self.dropout_2 = layers.Dropout(dropout)

    def call(self, inputs, mask=None, training=False):
        # Attention block. Inputs are (query, key, value)
        attention_out, attention_weights = self.relative_global_attention_1(
            (inputs, inputs, inputs), mask=mask
        )
        attention_out = self.dropout_1(attention_out, training=training)
        attention_out_normalized = self.layer_normalization_1(attention_out + inputs)

        ffn_out = self.feed_forward_network_pre(attention_out)
        ffn_out = self.feed_forward_network_pos(ffn_out)
        ffn_out = self.dropout_2(ffn_out, training=training)
        out = self.layer_normalization_2(attention_out_normalized + ffn_out)

        return out, attention_weights


"""
### Decoder

The Decoder layer is composed of multiple DecoderLayer blocks. It also includes
an embedding layer that converts our tokenized input into an embedding representation.
"""


@keras.utils.register_keras_serializable()
class Decoder(layers.Layer):
    def __init__(
        self, embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout
    ):
        super(Decoder, self).__init__()

        self.embedding_dim = embedding_dim
        self.num_blocks = num_blocks

        self.embedding = layers.Embedding(vocabulary_size, self.embedding_dim)
        self.positional_encoding = hub_layers.SinePositionEncoding()

        self.decode_layers = [
            DecoderLayer(
                embedding_dim, embedding_dim // 64, max_sequence_len, dropout=dropout
            )
            for _ in range(num_blocks)
        ]
        self.dropout = layers.Dropout(dropout)

    def call(self, inputs, mask=None, training=False, return_attention_weights=False):
        weights = []

        # Adding embedding and position encoding.
        x = self.embedding(inputs)
        x = x * ops.sqrt(ops.cast(self.embedding_dim, "float32"))
        x = x + self.positional_encoding(x)
        x = self.dropout(x, training=training)

        # Passing through the transformer blocks.
        for i in range(self.num_blocks):
            x, w = self.decode_layers[i](x, mask=mask, training=training)
            weights.append(w)
        if return_attention_weights:
            return x, weights
        return x
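
"""
As an optional smoke test, we can run a tiny batch of token ids through a small
decoder and confirm that it returns one attention-weight tensor per block. The
two-block configuration below is arbitrary and only used to keep the check fast.
"""
toy_decoder = Decoder(
    embedding_dim=CONFIG.embedding_dim,
    vocabulary_size=CONFIG.vocabulary_size,
    max_sequence_len=CONFIG.max_sequence_len,
    num_blocks=2,
    dropout=0.1,
)
toy_tokens = ops.convert_to_tensor([[1, 2, 3, 4]], dtype="int32")
toy_features, toy_weights = toy_decoder(toy_tokens, return_attention_weights=True)
print("Decoder output shape:", toy_features.shape)  # (batch, sequence, embedding_dim)
print("Attention weight tensors returned:", len(toy_weights))  # one per block
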
"""
### Music Transformer Decoder

With the above layers defined, we can now define the MusicTransformerDecoder model. It
applies a linear transformation to the output of the decoder to get the logits for
each token.
"""


@keras.utils.register_keras_serializable()
class MusicTransformerDecoder(keras.Model):
    def __init__(
        self,
        embedding_dim=CONFIG.embedding_dim,
        vocabulary_size=CONFIG.vocabulary_size,
        num_blocks=CONFIG.num_transformer_blocks,
        max_sequence_len=CONFIG.max_sequence_len,
        dropout=0.2,
    ):
        # Initialize attributes
        super(MusicTransformerDecoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.vocabulary_size = vocabulary_size
        self.num_blocks = num_blocks
        self.max_sequence_len = max_sequence_len

        # Initialize layers
        # Transformer decoder
        self.decoder = Decoder(
            embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout
        )
        # Output layer
        self.fc = layers.Dense(self.vocabulary_size, activation=None, name="output")

    @staticmethod
    def get_look_ahead_mask(max_sequence_len, inputs):
        sequence_length = min(max_sequence_len, inputs.shape[1])
        sequence_mask = ops.logical_not(
            ops.tri(sequence_length, sequence_length, dtype="bool")
        )

        inputs = ops.cast(inputs[:, None, None, :], "int32")
        output_pad_tensor = ops.ones_like(inputs) * CONFIG.token_pad
        decoder_output_mask = ops.equal(inputs, output_pad_tensor)
        return ops.cast(ops.logical_or(decoder_output_mask, sequence_mask), "int32")

    def call(self, inputs, training=False):
        mask = self.get_look_ahead_mask(self.max_sequence_len, inputs)
        decoding = self.decoder(
            inputs, mask=mask, training=training, return_attention_weights=False
        )
        return self.fc(decoding)

    # --- Sequence generation methods

    def generate(self, inputs: list, length=CONFIG.max_sequence_len, top_k=5):
        inputs = ops.convert_to_tensor([inputs])

        # Generate a new token using output distribution at given index
        def generate_token(inputs, end_idx):
            distribution = ops.stop_gradient(self.call(inputs)[0, end_idx])

            # Select the top-k tokens and their probabilities
            top_k_distribution, top_k_indices = ops.top_k(distribution, k=top_k)

            # Sample from the top-k probabilities
            new_token_idx = keras.random.categorical(top_k_distribution[None, :], 1)
            return ops.take(top_k_indices, new_token_idx[0])

        # Compute the number of tokens to add
        added_tokens = min(length, self.max_sequence_len - inputs.shape[1])
        progbar = utils.Progbar(added_tokens, unit_name="token", interval=5)

        # Pad the input sequence that will be filled with generated tokens
        out = ops.pad(inputs, ((0, 0), (0, added_tokens)), "constant", CONFIG.token_pad)

        # Generate tokens using top-k sampling
        for token_idx in range(inputs.shape[1] - 1, inputs.shape[1] - 1 + added_tokens):
            token = ops.cast(generate_token(out, end_idx=token_idx), out.dtype)
            out = ops.scatter_update(out, ((0, token_idx + 1),), token)
            progbar.add(1)

        return ops.convert_to_numpy(out[0])

    # --- Serialization methods

    def get_config(self):
        atts = ["embedding_dim", "vocabulary_size", "num_blocks", "max_sequence_len"]
        return {a: getattr(self, a) for a in atts}

    @classmethod
    def from_config(cls, config):
        return cls(**config)
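
"""
The model combines two masks before attending: a look-ahead (causal) mask that hides
future positions and a padding mask that hides padding tokens. As an optional
illustration, the snippet below builds the mask for a toy sequence whose last two
positions are padding; a 1 marks a position that attention is not allowed to use.
"""
toy_sequence = ops.convert_to_tensor([[5, 9, CONFIG.token_pad, CONFIG.token_pad]])
toy_mask = MusicTransformerDecoder.get_look_ahead_mask(
    CONFIG.max_sequence_len, toy_sequence
)
print(ops.convert_to_numpy(toy_mask)[0, 0])
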
"""
### Loss function

We define a custom loss function that computes the categorical cross-entropy
loss for the model. It is computed only for non-padding tokens and uses
`from_logits=True` since the model outputs logits.
"""


@keras.utils.register_keras_serializable()
def train_loss(y_true, y_pred):
    mask = ops.cast(ops.logical_not(ops.equal(y_true, CONFIG.token_pad)), "float32")
    y_true = ops.one_hot(ops.cast(y_true, "int32"), CONFIG.vocabulary_size)
    return ops.categorical_crossentropy(y_true, y_pred, from_logits=True) * mask


"""
### Learning rate schedule

Following the Music Transformer paper, we define an adapted exponential decay
learning rate schedule that takes into account the embedding dimension.
"""


@keras.utils.register_keras_serializable()
class CustomSchedule(optimizers.schedules.LearningRateSchedule):
    def __init__(self, embedding_dim, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        self.embedding_dim = embedding_dim
        self.warmup_steps = warmup_steps

        self._embedding_dim = ops.cast(self.embedding_dim, "float32")
        # Numerical stability adjustment on torch, which is less precise
        self._lr_adjust = 0.1 if keras.backend.backend() == "torch" else 1.0

    def get_config(self):
        return {"embedding_dim": self.embedding_dim, "warmup_steps": self.warmup_steps}

    def __call__(self, step):
        step_rsqrt = ops.rsqrt(ops.cast(step, "float32"))
        warmup_adjust = step * (self.warmup_steps**-1.5)
        output = ops.rsqrt(self._embedding_dim) * ops.minimum(step_rsqrt, warmup_adjust)
        return self._lr_adjust * output
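
"""
To see the warmup-then-decay shape of the schedule, we can optionally evaluate it at
a few training steps (the exact values also depend on the backend adjustment above):
"""
sample_schedule = CustomSchedule(CONFIG.embedding_dim)
for sample_step in [100, 1000, 4000, 20000]:
    print(f"step={sample_step}: lr={float(sample_schedule(sample_step)):.6f}")
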
"""
## Training the model

We can now train the model on the Maestro dataset. First, we define a training
function. This function compiles the model, trains it, and saves the best model
checkpoint. This way, we can continue training from the best model checkpoint
if needed.
"""


def train_model(model, train_ds, val_ds, epochs=15):
    # Configure optimizer
    learning_rate = CustomSchedule(CONFIG.embedding_dim)
    optimizer = optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

    # Compile the model
    model.compile(optimizer=optimizer, loss=train_loss)

    # Train the model
    save_cb = callbacks.ModelCheckpoint(CONFIG.model_out, save_best_only=True)
    model.fit(
        train_ds, validation_data=val_ds, epochs=epochs, callbacks=[save_cb], verbose=2
    )
    return model


"""
If a model checkpoint already exists, we load it instead of training from scratch;
otherwise, we train a new model and save the best checkpoint.
"""
if path.exists(CONFIG.model_out):
    model = keras.models.load_model(CONFIG.model_out)
    # Uncomment to continue training from the loaded checkpoint
    # train_model(model, train_dataset, val_dataset, epochs=10)
else:
    # Train the model
    model = train_model(MusicTransformerDecoder(), train_dataset, val_dataset)


"""
## Generate music

We can now generate music using the trained model. We use an existing MIDI file
as a seed and generate a new sequence.
"""


def generate_music(model, seed_path, length=1024, out_dir=None, top_k=None):
    # Ensure the output directory exists
    out_dir = out_dir if out_dir is not None else tempfile.mkdtemp()
    os.makedirs(out_dir, exist_ok=True)

    # Get some tokens from the MIDI file
    inputs = midi_tokenizer.encode_midi(seed_path)[100:125]
    print(f"Seed tokens: {inputs}")

    # Generate music that follows the input tokens until the maximum length
    result = model.generate(inputs, length=length, top_k=top_k)

    output_path = path.join(out_dir, path.basename(seed_path).split(".")[0] + ".mid")
    midi_tokenizer.decode_midi(result, output_path)
    return output_path


output_file = generate_music(model, val_paths[-1], out_dir="tmp/", top_k=15)
print(visualize_midi(output_file, out_dir="tmp/"))  # Saved audio path
visualize_midi(output_file)  # Display the audio if in a Jupyter notebook

"""
## Conclusion

In this example, we learned how to build a music generation model using a custom
Transformer decoder architecture.

We followed the Music Transformer paper by Huang et al. (2018). To do so, we had to:

- Define a custom loss function and learning rate schedule.
- Define a custom attention mechanism.
- Preprocess MIDI files into a tokenized format.

After training the model on the Maestro dataset, we generated music sequences
using a seed MIDI file.

### Next steps

We could further improve inference times by caching attention weights during the
forward pass, similarly to `keras_hub` `CausalLM` models, which use the
`CachedMultiHeadAttention` layer.
"""