Music Generation with Transformer Models
Author: Joaquin Jimenez
Date created: 2024/11/22
Last modified: 2024/11/26
Description: Use a Transformer model to train on MIDI data and generate music sequences.
Introduction
In this tutorial, we learn how to build a music generation model using a decoder-only Transformer architecture. The model is trained on the Maestro dataset and implemented using Keras 3. In the process, we explore MIDI tokenization and the relative global attention mechanism.
This example is based on the paper "Music Transformer" by Huang et al. (2018). Check out the original paper and code.
Setup
Before we start, let's install and import all the libraries we need.
Optional dependencies
To hear the audio, install the following additional dependencies:
Configuration
Let's define the configuration for the model and the dataset used in this example.
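For orientation, a minimal configuration might look like the sketch below. The names and values are illustrative assumptions (for example, 388 is the event-vocabulary size used in the Music Transformer paper), not the exact settings used here.

```python
# Illustrative configuration sketch; names and values are assumptions.
CONFIG = {
    "data_dir": "maestro_midi",   # directory holding the extracted MIDI files
    "sequence_length": 2048,      # tokens per training example
    "vocabulary_size": 388,       # event vocabulary size (as in the paper)
    "embedding_dim": 256,         # model / embedding dimension
    "num_layers": 6,              # number of decoder blocks
    "num_heads": 8,               # attention heads per block
    "batch_size": 16,
    "epochs": 10,
}
```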
Maestro dataset
The Maestro dataset contains MIDI files for piano performances.
Download the dataset
We now download and extract the dataset, then move the MIDI files to a new directory.
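A hedged sketch of this step is shown below; the archive URL is assumed to be the MIDI-only release hosted by the Magenta project, and the target directory name is arbitrary.

```python
import glob
import os
import shutil

import keras

# Download and extract the Maestro archive (URL assumed: the v3.0.0 MIDI-only
# release hosted by the Magenta project).
archive = keras.utils.get_file(
    origin=(
        "https://storage.googleapis.com/magentadata/datasets/maestro/"
        "v3.0.0/maestro-v3.0.0-midi.zip"
    ),
    extract=True,
)
# Depending on the Keras version, get_file returns either the archive path or
# the extracted directory; search under the relevant directory either way.
base_dir = archive if os.path.isdir(archive) else os.path.dirname(archive)
output_dir = "maestro_midi"
os.makedirs(output_dir, exist_ok=True)
for path in glob.glob(os.path.join(base_dir, "**", "*.mid*"), recursive=True):
    shutil.move(path, os.path.join(output_dir, os.path.basename(path)))
```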
Split the dataset
We can now split the dataset into training and validation sets.
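One simple way to do this, assuming the flat `maestro_midi` directory from the previous step and an arbitrary 90/10 split:

```python
import glob
import os
import random

midi_files = sorted(glob.glob(os.path.join("maestro_midi", "*.mid*")))
# Deterministic shuffle, then hold out 10% of the files for validation
# (the split ratio and seed are assumptions).
random.Random(42).shuffle(midi_files)
num_val = max(1, int(0.1 * len(midi_files)))
val_files, train_files = midi_files[:num_val], midi_files[num_val:]
print(f"{len(train_files)} training files, {len(val_files)} validation files")
```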
Hear a MIDI file
We use the pretty_midi library and fluidsynth to convert MIDI files into waveform audio. This allows us to listen to the data samples before and after processing.
The following dependencies are required to play the audio:
fluidsynth:
sudo apt install -y fluidsynth
pyfluidsynth, scipy:
pip install pyfluidsynth scipy
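With those installed, a small helper along these lines can render a MIDI file to audio for playback in the notebook (the sampling rate is an assumption):

```python
import pretty_midi
from IPython.display import Audio


def midi_to_audio(path, sampling_rate=16000):
    # Render the MIDI file to a waveform with FluidSynth (needs pyfluidsynth
    # plus a system FluidSynth install) and wrap it for inline playback.
    midi = pretty_midi.PrettyMIDI(path)
    waveform = midi.fluidsynth(fs=sampling_rate)
    return Audio(waveform, rate=sampling_rate)


# midi_to_audio(train_files[0])  # listen to one of the training samples
```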
Tokenize the data
We now preprocess the MIDI files into a tokenized format for training.
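The full example uses an event-based encoding in the spirit of the Music Transformer paper (note-on, note-off, time-shift and velocity events). As a pared-down sketch of the idea, the tokenizer below emits only note-on, note-off and quantized time-shift tokens; the quantization step and vocabulary layout are assumptions.

```python
import pretty_midi

NUM_PITCHES = 128
TIME_STEP = 0.01                       # 10 ms time quantization (assumed)
MAX_SHIFT = 100                        # at most 1 second per TIME_SHIFT token
NOTE_ON_OFFSET = 0                     # tokens [0, 128): note-on per pitch
NOTE_OFF_OFFSET = NUM_PITCHES          # tokens [128, 256): note-off per pitch
TIME_SHIFT_OFFSET = 2 * NUM_PITCHES    # tokens [256, 356): time shifts


def tokenize_midi(path):
    midi = pretty_midi.PrettyMIDI(path)
    # Flatten all notes into (time, token) events.
    events = []
    for instrument in midi.instruments:
        for note in instrument.notes:
            events.append((note.start, NOTE_ON_OFFSET + note.pitch))
            events.append((note.end, NOTE_OFF_OFFSET + note.pitch))
    events.sort()

    tokens, previous_time = [], 0.0
    for time, token in events:
        # Encode the elapsed time as one or more TIME_SHIFT tokens.
        shift = int(round((time - previous_time) / TIME_STEP))
        while shift > 0:
            step = min(shift, MAX_SHIFT)
            tokens.append(TIME_SHIFT_OFFSET + step - 1)
            shift -= step
        tokens.append(token)
        previous_time = time
    return tokens
```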
Dataset objects
We now define a dataset class that yields batches of input sequences and target sequences.
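A sketch of such a class is shown below, built on keras.utils.PyDataset. The target sequence is the input shifted by one token; using token id 0 for padding and the class/attribute names are assumptions.

```python
import numpy as np

import keras


class MidiDataset(keras.utils.PyDataset):
    """Yields (input, target) batches where targets are inputs shifted by one."""

    def __init__(self, token_sequences, sequence_length, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.sequences = token_sequences
        self.sequence_length = sequence_length
        self.batch_size = batch_size

    def __len__(self):
        return len(self.sequences) // self.batch_size

    def __getitem__(self, index):
        batch = self.sequences[
            index * self.batch_size : (index + 1) * self.batch_size
        ]
        x = np.zeros((len(batch), self.sequence_length), dtype="int32")
        y = np.zeros((len(batch), self.sequence_length), dtype="int32")
        for i, tokens in enumerate(batch):
            # Crop to sequence_length + 1 tokens (shorter pieces stay zero-padded),
            # then shift by one to build the next-token-prediction target.
            window = tokens[: self.sequence_length + 1]
            x[i, : len(window) - 1] = window[:-1]
            y[i, : len(window) - 1] = window[1:]
        return x, y
```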
Model definition
It is time to define the model architecture. We use a Transformer decoder architecture with a custom attention mechanism, relative global attention.
Relative Global Attention
The following code implements the Relative Global Attention layer. It is used in place of the standard multi-head attention layer in the Transformer decoder. The main difference is that it incorporates learned relative position embeddings, which let the model attend based on the relative distance between tokens rather than only their absolute positions.
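The core trick is the "skewing" procedure from the paper, which aligns the product of the queries with the relative position embeddings so that each entry corresponds to the correct relative distance. A minimal sketch with keras.ops (the tensor layout (batch, heads, seq_len, seq_len) is assumed):

```python
from keras import ops


def skew(relative_logits):
    # relative_logits: Q @ E^T of shape (batch, heads, seq_len, seq_len),
    # where E holds the learned relative position embeddings.
    batch, heads, length, _ = ops.shape(relative_logits)
    # 1. Pad one zero column on the left of the last axis: (..., L, L + 1).
    padded = ops.pad(relative_logits, [[0, 0], [0, 0], [0, 0], [1, 0]])
    # 2. Reshape so the padding staggers the rows: (..., L + 1, L).
    reshaped = ops.reshape(padded, (batch, heads, length + 1, length))
    # 3. Drop the first (now meaningless) row to recover (..., L, L).
    return reshaped[:, :, 1:, :]
```

The skewed relative logits are added to the usual query-key scores before the softmax inside the attention layer.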
Decoder Layer
Using the RelativeGlobalAttention layer, we can define the DecoderLayer. It closely follows the standard Transformer decoder layer, with the custom attention mechanism swapped in for regular multi-head attention.
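As a rough sketch of the block structure, the layer below uses the built-in keras.layers.MultiHeadAttention as a stand-in for RelativeGlobalAttention; the feed-forward width, dropout rate and post-norm residual placement are assumptions.

```python
import keras
from keras import layers


class TransformerDecoderBlock(layers.Layer):
    def __init__(self, embedding_dim, num_heads, feedforward_dim, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        # Stand-in for RelativeGlobalAttention in this sketch.
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embedding_dim // num_heads, dropout=dropout
        )
        self.feedforward = keras.Sequential(
            [layers.Dense(feedforward_dim, activation="relu"), layers.Dense(embedding_dim)]
        )
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout = layers.Dropout(dropout)

    def call(self, inputs, training=False):
        # Causal self-attention: each position attends only to earlier tokens.
        attention_out = self.attention(
            inputs, inputs, use_causal_mask=True, training=training
        )
        x = self.norm1(inputs + attention_out)
        ff_out = self.dropout(self.feedforward(x), training=training)
        return self.norm2(x + ff_out)
```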
Decoder
The Decoder layer is composed of multiple DecoderLayer blocks. It also includes an embedding layer that converts our tokenized input into an embedding representation.
Music Transformer Decoder
With the above layers defined, we can now define the MusicTransformerDecoder model. It applies a linear transformation to the output of the decoder to get the logits for each token.
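Putting the pieces together, a minimal end-to-end sketch (reusing the TransformerDecoderBlock stand-in above; names and sizes are assumptions) could look as follows. The final Dense layer has no activation, so the model outputs raw logits over the token vocabulary. In this simplified version the only positional signal would come from the attention layer, which is exactly what the relative attention mechanism provides in the full model.

```python
import keras
from keras import layers


class MusicTransformerSketch(keras.Model):
    def __init__(self, vocabulary_size, embedding_dim, num_layers, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embedding = layers.Embedding(vocabulary_size, embedding_dim)
        self.blocks = [
            TransformerDecoderBlock(embedding_dim, num_heads, 4 * embedding_dim)
            for _ in range(num_layers)
        ]
        self.to_logits = layers.Dense(vocabulary_size)  # linear head -> logits

    def call(self, inputs, training=False):
        x = self.embedding(inputs)
        for block in self.blocks:
            x = block(x, training=training)
        return self.to_logits(x)
```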
Loss function
We define a custom loss function that computes the categorical cross-entropy loss for the model. It is computed only for non-padding tokens and uses from_logits=True, since the model outputs logits.
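A sketch of such a masked loss is shown below; it uses the sparse variant because the targets are integer token ids, and it assumes token id 0 is the padding token.

```python
import keras
from keras import ops


def masked_loss(y_true, y_pred):
    # Per-token cross-entropy on raw logits.
    loss = keras.losses.sparse_categorical_crossentropy(
        y_true, y_pred, from_logits=True
    )
    # Zero out padding positions and average over the remaining tokens.
    mask = ops.cast(ops.not_equal(y_true, 0), loss.dtype)
    return ops.sum(loss * mask) / ops.maximum(ops.sum(mask), 1.0)
```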
Learning rate schedule
Following the Music Transformer paper, we define an adapted exponential decay learning rate schedule that takes into account the embedding dimension.
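One way to express such a schedule, shown purely as a sketch: an ExponentialDecay whose initial rate is scaled by the inverse square root of the embedding dimension. The decay steps and rate are assumptions, not the tutorial's exact values, and the real schedule may additionally use warmup.

```python
import keras

embedding_dim = 256  # must match the model's embedding dimension

# Exponential decay scaled by the embedding dimension (values are assumed).
learning_rate = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=embedding_dim**-0.5,
    decay_steps=10_000,
    decay_rate=0.9,
)
```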
Training the model
We can now train the model on the Maestro dataset. First, we define a training function. This function compiles the model, trains it, and saves the best model checkpoint. This way, we can continue training from the best model checkpoint if needed.
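A hedged sketch of such a training function is shown below; it reuses the masked_loss and learning_rate sketches from above, and the checkpoint path and epoch count are assumptions.

```python
import os

import keras


def train_model(model, train_dataset, val_dataset, epochs=10,
                checkpoint_path="music_transformer.weights.h5"):
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate), loss=masked_loss
    )
    # Resume from the best weights if a checkpoint already exists
    # (assumes the model has already been built, e.g. by calling it once).
    if os.path.exists(checkpoint_path):
        model.load_weights(checkpoint_path)
    checkpoint = keras.callbacks.ModelCheckpoint(
        checkpoint_path,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True,
    )
    model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint],
    )
    return model
```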
With the training function defined, we can train the model. If a model checkpoint already exists, we load it and continue training from there.
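Assumed end-to-end usage of the sketches above (tokenizer, dataset, model and training function), with hyperparameters mirroring the illustrative configuration:

```python
import numpy as np

train_dataset = MidiDataset(
    [tokenize_midi(path) for path in train_files], sequence_length=2048, batch_size=16
)
val_dataset = MidiDataset(
    [tokenize_midi(path) for path in val_files], sequence_length=2048, batch_size=16
)
model = MusicTransformerSketch(
    vocabulary_size=388, embedding_dim=256, num_layers=6, num_heads=8
)
model(np.zeros((1, 16), dtype="int32"))  # build the model so weights can be loaded
model = train_model(model, train_dataset, val_dataset)
```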
Generate music
We can now generate music using the trained model. We use an existing MIDI file as a seed and generate a new sequence.
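A sampling loop along the following lines would do this: tokenize a seed MIDI file, then repeatedly predict next-token logits and sample with a temperature. The context length, temperature and helper names are assumptions.

```python
import numpy as np
from keras import ops


def generate_tokens(model, seed_tokens, num_steps=512, temperature=1.0,
                    max_context=2048):
    tokens = list(seed_tokens)
    for _ in range(num_steps):
        # Feed the most recent tokens and read the logits at the last position.
        context = np.array([tokens[-max_context:]], dtype="int32")
        logits = model(context, training=False)[0, -1]
        probabilities = ops.convert_to_numpy(ops.softmax(logits / temperature))
        probabilities = probabilities / probabilities.sum()  # guard rounding error
        next_token = np.random.choice(len(probabilities), p=probabilities)
        tokens.append(int(next_token))
    return tokens


# generated = generate_tokens(model, tokenize_midi(val_files[0])[:512])
```

The generated token sequence can then be decoded back into a MIDI file by inverting the tokenization step.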
Conclusion
In this example, we learned how to build a music generation model using a custom Transformer decoder architecture.
We did so following the Music Transformer paper by Huang et al. (2018), which required us to:
Define a custom loss function and learning rate schedule.
Define a custom attention mechanism.
Preprocess MIDI files into a tokenized format.
After training the model on the Maestro dataset, we generated music sequences using a seed MIDI file.
Next steps
We could further improve inference times by caching attention weights during the forward pass, similarly to keras_hub CausalLM models, which use the CachedMultiHeadAttention layer.