Path: blob/master/examples/audio/vocal_track_separation.py
"""1Title: Vocal Track Separation with Encoder-Decoder Architecture2Author: [Joaquin Jimenez](https://github.com/johacks/)3Date created: 2024/12/104Last modified: 2024/12/105Description: Train a model to separate vocal tracks from music mixtures.6Accelerator: GPU7"""89"""10## Introduction1112In this tutorial, we build a vocal track separation model using an encoder-decoder13architecture in Keras 3.1415We train the model on the [MUSDB18 dataset](https://doi.org/10.5281/zenodo.1117372),16which provides music mixtures and isolated tracks for drums, bass, other, and vocals.1718Key concepts covered:1920- Audio data preprocessing using the Short-Time Fourier Transform (STFT).21- Audio data augmentation techniques.22- Implementing custom encoders and decoders specialized for audio data.23- Defining appropriate loss functions and metrics for audio source separation tasks.2425The model architecture is derived from the TFC_TDF_Net model described in:2627W. Choi, M. Kim, J. Chung, D. Lee, and S. Jung, “Investigating U-Nets with various28intermediate blocks for spectrogram-based singing voice separation,” in the 21st29International Society for Music Information Retrieval Conference, 2020.3031For reference code, see:32[GitHub: ws-choi/ISMIR2020_U_Nets_SVS](https://github.com/ws-choi/ISMIR2020_U_Nets_SVS).3334The data processing and model training routines are partly derived from:35[ZFTurbo/Music-Source-Separation-Training](https://github.com/ZFTurbo/Music-Source-Separation-Training/tree/main).36"""3738"""39## Setup4041Import and install all the required dependencies.42"""4344"""shell45pip install -qq audiomentations soundfile ffmpeg-binaries46pip install -qq "keras==3.7.0"47sudo -n apt-get install -y graphviz >/dev/null 2>&1 # Required for plotting the model48"""4950import glob51import os5253os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch"5455import random56import subprocess57import tempfile58import typing59from os import path6061import audiomentations as aug62import ffmpeg63import keras64import numpy as np65import soundfile as sf66from IPython import display67from keras import callbacks, layers, ops, saving68from matplotlib import pyplot as plt6970"""71## Configuration7273The following constants define configuration parameters for audio processing74and model training, including dataset paths, audio chunk sizes, Short-Time Fourier75Transform (STFT) parameters, and training hyperparameters.76"""7778# MUSDB18 dataset configuration79MUSDB_STREAMS = {"mixture": 0, "drums": 1, "bass": 2, "other": 3, "vocals": 4}80TARGET_INSTRUMENTS = {track: MUSDB_STREAMS[track] for track in ("vocals",)}81N_INSTRUMENTS = len(TARGET_INSTRUMENTS)82SOURCE_INSTRUMENTS = tuple(k for k in MUSDB_STREAMS if k != "mixture")8384# Audio preprocessing parameters for Short-Time Fourier Transform (STFT)85N_SUBBANDS = 4 # Number of subbands into which frequencies are split86CHUNK_SIZE = 65024 # Number of amplitude samples per audio chunk (~4 seconds)87STFT_N_FFT = 2048 # FFT points used in STFT88STFT_HOP_LENGTH = 512 # Hop length for STFT8990# Training hyperparameters91N_CHANNELS = 64 # Base channel count for the model92BATCH_SIZE = 393ACCUMULATION_STEPS = 294EFFECTIVE_BATCH_SIZE = BATCH_SIZE * (ACCUMULATION_STEPS or 1)9596# Paths97TMP_DIR = path.expanduser("~/.keras/tmp")98DATASET_DIR = path.expanduser("~/.keras/datasets")99MODEL_PATH = path.join(TMP_DIR, f"model_{keras.backend.backend()}.keras")100CSV_LOG_PATH = path.join(TMP_DIR, f"training_{keras.backend.backend()}.csv")101os.makedirs(DATASET_DIR, 

"""
## MUSDB18 Dataset

The MUSDB18 dataset is a standard benchmark for music source separation, containing
150 full-length music tracks along with isolated drums, bass, other, and vocals.
The dataset is stored in .mp4 format, and each .mp4 file includes multiple audio
streams (mixture and individual tracks).

### Download and Conversion

The following utility function downloads MUSDB18 and converts its .mp4 files to
.wav files for each instrument track, resampled to 16 kHz.
"""


def download_musdb18(out_dir=None):
    """Download and extract the MUSDB18 dataset, then convert .mp4 files to .wav files.

    MUSDB18 reference:
    Rafii, Z., Liutkus, A., Stöter, F.-R., Mimilakis, S. I., & Bittner, R. (2017).
    MUSDB18 - a corpus for music separation (1.0.0) [Data set]. Zenodo.
    """
    ffmpeg.init()
    from ffmpeg import FFMPEG_PATH

    # Create output directories
    os.makedirs((base := out_dir or tempfile.mkdtemp()), exist_ok=True)
    if path.exists((out_dir := path.join(base, "musdb18_wav"))):
        print("MUSDB18 dataset already downloaded")
        return out_dir

    # Download and extract the dataset
    download_dir = keras.utils.get_file(
        fname="musdb18",
        origin="https://zenodo.org/records/1117372/files/musdb18.zip",
        extract=True,
    )

    # ffmpeg command template: input, stream index, output
    ffmpeg_args = str(FFMPEG_PATH) + " -v error -i {} -map 0:{} -vn -ar 16000 {}"

    # Convert each mp4 file to multiple .wav files for each track
    for split in ("train", "test"):
        songs = os.listdir(path.join(download_dir, split))
        for i, song in enumerate(songs):
            if i % 10 == 0:
                print(f"{split.capitalize()}: {i}/{len(songs)} songs processed")

            mp4_path_orig = path.join(download_dir, split, song)
            mp4_path = path.join(tempfile.mkdtemp(), split, song.replace(" ", "_"))
            os.makedirs(path.dirname(mp4_path), exist_ok=True)
            os.rename(mp4_path_orig, mp4_path)

            wav_dir = path.join(out_dir, split, path.basename(mp4_path).split(".")[0])
            os.makedirs(wav_dir, exist_ok=True)

            for track in SOURCE_INSTRUMENTS:
                out_path = path.join(wav_dir, f"{track}.wav")
                stream_index = MUSDB_STREAMS[track]
                args = ffmpeg_args.format(mp4_path, stream_index, out_path).split()
                assert subprocess.run(args).returncode == 0, "ffmpeg conversion failed"
    return out_dir


# Download and prepare the MUSDB18 dataset
songs = download_musdb18(out_dir=DATASET_DIR)
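
"""
To make the on-disk layout concrete (an illustrative addition, not part of the original
example), each song directory should now contain one 16 kHz .wav file per instrument
track:
"""

example_song = sorted(glob.glob(path.join(songs, "train", "*")))[0]
print(path.basename(example_song), "->", sorted(os.listdir(example_song)))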

"""
### Custom Dataset

We define a custom dataset class to generate random audio chunks and their corresponding
labels. The dataset does the following:

1. Selects a random chunk from a random song and instrument.
2. Applies optional data augmentations.
3. Combines isolated tracks to form new synthetic mixtures.
4. Prepares features (mixtures) and labels (vocals) for training.

This approach yields an effectively unlimited variety of training examples through
randomization and augmentation.
"""


class Dataset(keras.utils.PyDataset):
    def __init__(
        self,
        songs,
        batch_size=BATCH_SIZE,
        chunk_size=CHUNK_SIZE,
        batches_per_epoch=1000 * ACCUMULATION_STEPS,
        augmentation=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.augmentation = augmentation
        self.vocals_augmentations = [
            aug.PitchShift(min_semitones=-5, max_semitones=5, p=0.1),
            aug.SevenBandParametricEQ(-9, 9, p=0.25),
            aug.TanhDistortion(0.1, 0.7, p=0.1),
        ]
        self.other_augmentations = [
            aug.PitchShift(p=0.1),
            aug.AddGaussianNoise(p=0.1),
        ]
        self.songs = songs
        self.sizes = {song: self.get_track_set_size(song) for song in self.songs}
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.batches_per_epoch = batches_per_epoch

    def get_track_set_size(self, song: str):
        """Return the smallest track length in the given song directory."""
        sizes = [len(sf.read(p)[0]) for p in glob.glob(path.join(song, "*.wav"))]
        if max(sizes) != min(sizes):
            print(f"Warning: {song} has different track lengths")
        return min(sizes)

    def random_chunk_of_instrument_type(self, instrument: str):
        """Extract a random chunk for the specified instrument from a random song."""
        song, size = random.choice(list(self.sizes.items()))
        track = path.join(song, f"{instrument}.wav")

        if self.chunk_size <= size:
            start = np.random.randint(size - self.chunk_size + 1)
            audio = sf.read(track, self.chunk_size, start, dtype="float32")[0]
            audio_mono = np.mean(audio, axis=1)
        else:
            # If the track is shorter than chunk_size, pad the signal
            audio_mono = np.mean(sf.read(track, dtype="float32")[0], axis=1)
            audio_mono = np.pad(audio_mono, ((0, self.chunk_size - size),))

        # If the chunk is almost silent, retry with a different random chunk
        if np.mean(np.abs(audio_mono)) < 0.01:
            return self.random_chunk_of_instrument_type(instrument)

        return self.data_augmentation(audio_mono, instrument)

    def data_augmentation(self, audio: np.ndarray, instrument: str):
        """Apply data augmentation to the audio chunk, if enabled."""

        def coin_flip(x, probability: float, fn: typing.Callable):
            return fn(x) if random.uniform(0, 1) < probability else x

        if self.augmentation:
            augmentations = (
                self.vocals_augmentations
                if instrument == "vocals"
                else self.other_augmentations
            )
            # Loudness augmentation
            audio *= np.random.uniform(0.5, 1.5, (len(audio),)).astype("float32")
            # Random reverse
            audio = coin_flip(audio, 0.1, lambda x: np.flip(x))
            # Random polarity inversion
            audio = coin_flip(audio, 0.5, lambda x: -x)
            # Apply selected augmentations
            for aug_ in augmentations:
                aug_.randomize_parameters(audio, sample_rate=16000)
                audio = aug_(audio, sample_rate=16000)
        return audio

    def random_mix_of_tracks(self) -> dict:
        """Create a random mix of instruments by summing their individual chunks."""
        tracks = {}
        for instrument in SOURCE_INSTRUMENTS:
            # Start with a single random chunk
            mixup = [self.random_chunk_of_instrument_type(instrument)]

            # Randomly add more chunks of the same instrument (mixup augmentation)
            if self.augmentation:
                for p in (0.2, 0.02):
                    if random.uniform(0, 1) < p:
                        mixup.append(self.random_chunk_of_instrument_type(instrument))

            tracks[instrument] = np.mean(mixup, axis=0, dtype="float32")
        return tracks

    def __len__(self):
        return self.batches_per_epoch

    def __getitem__(self, idx):
        # Generate a batch of random mixtures
        batch = [self.random_mix_of_tracks() for _ in range(self.batch_size)]

        # Features: sum of all tracks
        batch_x = ops.sum(
            np.array([list(track_set.values()) for track_set in batch]), axis=1
        )

        # Labels: isolated target instruments (e.g., vocals)
        batch_y = np.array(
            [[track_set[t] for t in TARGET_INSTRUMENTS] for track_set in batch]
        )

        return batch_x, ops.convert_to_tensor(batch_y)


# Create train and validation datasets
train_ds = Dataset(glob.glob(path.join(songs, "train", "*")))
val_ds = Dataset(
    glob.glob(path.join(songs, "test", "*")),
    batches_per_epoch=int(0.1 * train_ds.batches_per_epoch),
    augmentation=False,
)
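
"""
For reference (an illustrative addition, not in the original example), each generated
batch pairs a `(batch_size, CHUNK_SIZE)` mixture array with a
`(batch_size, N_INSTRUMENTS, CHUNK_SIZE)` target array.
"""

print("Batches per epoch (train / validation):", len(train_ds), "/", len(val_ds))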

"""
### Visualize a Sample

Let's visualize a random mixed audio chunk and its corresponding isolated vocals.
This helps to understand the nature of the preprocessed input data.
"""


def visualize_audio_np(audio: np.ndarray, rate=16000, name="mixup"):
    """Plot the audio waveform and display an audio playback widget."""
    plt.figure(figsize=(10, 6))
    plt.plot(audio)
    plt.title(f"Waveform: {name}")
    plt.xlim(0, len(audio))
    plt.ylabel("Amplitude")
    plt.show()
    # plt.savefig(f"tmp/{name}.png")

    # Normalize and display audio
    audio_norm = (audio - np.min(audio)) / (np.max(audio) - np.min(audio) + 1e-8)
    audio_norm = (audio_norm * 2 - 1) * 0.6
    display.display(display.Audio(audio_norm, rate=rate))
    # sf.write(f"tmp/{name}.wav", audio_norm, rate)


sample_batch_x, sample_batch_y = val_ds[None]  # Random batch
visualize_audio_np(ops.convert_to_numpy(sample_batch_x[0]))
visualize_audio_np(ops.convert_to_numpy(sample_batch_y[0, 0]), name="vocals")

"""
## Model

### Preprocessing

The model operates on STFT representations rather than raw audio. We define functions
to compute the STFT and the corresponding inverse transform (iSTFT).
"""


def stft(inputs, fft_size=STFT_N_FFT, sequence_stride=STFT_HOP_LENGTH):
    """Compute the STFT for the input audio and return the real and imaginary parts."""
    real_x, imag_x = ops.stft(inputs, fft_size, sequence_stride, fft_size)
    real_x, imag_x = ops.expand_dims(real_x, -1), ops.expand_dims(imag_x, -1)
    x = ops.concatenate((real_x, imag_x), axis=-1)

    # Drop the last frequency bin so the frequency dimension divides evenly into subbands
    return ops.split(x, [x.shape[2] - 1], axis=2)[0]


def inverse_stft(inputs, fft_size=STFT_N_FFT, sequence_stride=STFT_HOP_LENGTH):
    """Compute the inverse STFT for the given STFT input."""
    x = inputs

    # Pad back the dropped frequency bin if using the torch backend
    if keras.backend.backend() == "torch":
        x = ops.pad(x, ((0, 0), (0, 0), (0, 1), (0, 0)))

    real_x, imag_x = ops.split(x, 2, axis=-1)
    real_x = ops.squeeze(real_x, axis=-1)
    imag_x = ops.squeeze(imag_x, axis=-1)

    return ops.istft((real_x, imag_x), fft_size, sequence_stride, fft_size)
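
"""
As an illustrative check (an addition, not in the original example), we can apply the
STFT preprocessing to the sample batch and inspect its shape: the last axis holds the
real and imaginary parts, and the frequency axis has `STFT_N_FFT // 2` bins after
dropping the last one.
"""

sample_spec = stft(sample_batch_x)
print("STFT shape (batch, frames, freq_bins, real/imag):", sample_spec.shape)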

"""
### Model Architecture

The model uses a custom encoder-decoder architecture with Time-Frequency Convolution
(TFC) and Time-Distributed Fully Connected (TDF) blocks. They are grouped into a
`TimeFrequencyTransformBlock`, i.e. "TFC_TDF" in the original paper by Choi et al.

We then define an encoder-decoder network with multiple scales. Each encoder scale
applies TFC_TDF blocks followed by downsampling, while decoder scales apply TFC_TDF
blocks over the concatenation of upsampled features and associated encoder outputs.
"""


@saving.register_keras_serializable()
class TimeDistributedDenseBlock(layers.Layer):
    """Time-Distributed Fully Connected layer block.

    Applies frequency-wise dense transformations across time frames with instance
    normalization and GELU activation.
    """

    def __init__(self, bottleneck_factor, fft_dim, **kwargs):
        super().__init__(**kwargs)
        self.fft_dim = fft_dim
        self.hidden_dim = fft_dim // bottleneck_factor

    def build(self, *_):
        self.group_norm_1 = layers.GroupNormalization(groups=-1)
        self.group_norm_2 = layers.GroupNormalization(groups=-1)
        self.dense_1 = layers.Dense(self.hidden_dim, use_bias=False)
        self.dense_2 = layers.Dense(self.fft_dim, use_bias=False)

    def call(self, x):
        # Apply normalization and dense layers frequency-wise
        x = ops.gelu(self.group_norm_1(x))
        x = ops.swapaxes(x, -1, -2)
        x = self.dense_1(x)

        x = ops.gelu(self.group_norm_2(ops.swapaxes(x, -1, -2)))
        x = ops.swapaxes(x, -1, -2)
        x = self.dense_2(x)
        return ops.swapaxes(x, -1, -2)


@saving.register_keras_serializable()
class TimeFrequencyConvolution(layers.Layer):
    """Time-Frequency Convolutional layer.

    Applies instance normalization and GELU activation followed by a 2D convolution
    over the time-frequency representation.
    """

    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels

    def build(self, *_):
        self.group_norm = layers.GroupNormalization(groups=-1)
        self.conv = layers.Conv2D(self.channels, 3, padding="same", use_bias=False)

    def call(self, x):
        return self.conv(ops.gelu(self.group_norm(x)))


@saving.register_keras_serializable()
class TimeFrequencyTransformBlock(layers.Layer):
    """Implements the TFC_TDF block of the encoder-decoder architecture.

    Applies Time-Frequency Convolution and Time-Distributed Dense blocks as many
    times as specified by the `length` parameter.
    """

    def __init__(
        self, channels, length, fft_dim, bottleneck_factor, in_channels=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.length = length
        self.fft_dim = fft_dim
        self.bottleneck_factor = bottleneck_factor
        self.in_channels = in_channels or channels

    def build(self, *_):
        self.blocks = []
        # Add blocks in a flat list to avoid nested structures
        for i in range(self.length):
            in_channels = self.channels if i > 0 else self.in_channels
            self.blocks.append(TimeFrequencyConvolution(in_channels))
            self.blocks.append(
                TimeDistributedDenseBlock(self.bottleneck_factor, self.fft_dim)
            )
            self.blocks.append(TimeFrequencyConvolution(self.channels))
            # Residual connection
            self.blocks.append(layers.Conv2D(self.channels, 1, 1, use_bias=False))

    def call(self, inputs):
        x = inputs
        # Each block consists of 4 layers:
        # 1. Time-Frequency Convolution
        # 2. Time-Distributed Dense
        # 3. Time-Frequency Convolution
        # 4. Residual connection
        for i in range(0, len(self.blocks), 4):
            tfc_1 = self.blocks[i](x)
            tdf = self.blocks[i + 1](x)
            tfc_2 = self.blocks[i + 2](tfc_1 + tdf)
            x = tfc_2 + self.blocks[i + 3](x)  # Residual connection
        return x
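
"""
A minimal, illustrative usage sketch of the block defined above (an addition, not part
of the original example; `demo_block` is a throwaway name): the block preserves the
time-frequency grid and maps the last axis to `channels`, so it can be stacked freely
inside the encoder-decoder.
"""

demo_block = TimeFrequencyTransformBlock(
    channels=8, length=1, fft_dim=16, bottleneck_factor=2
)
print("TFC_TDF output shape:", demo_block(ops.ones((1, 32, 16, 8))).shape)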


@saving.register_keras_serializable()
class Downscale(layers.Layer):
    """Downscale time-frequency dimensions using a convolution."""

    conv_cls = layers.Conv2D

    def __init__(self, channels, scale, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.scale = scale

    def build(self, *_):
        self.conv = self.conv_cls(self.channels, self.scale, self.scale, use_bias=False)
        self.norm = layers.GroupNormalization(groups=-1)

    def call(self, inputs):
        return self.norm(ops.gelu(self.conv(inputs)))


@saving.register_keras_serializable()
class Upscale(Downscale):
    """Upscale time-frequency dimensions using a transposed convolution."""

    conv_cls = layers.Conv2DTranspose


def build_model(
    inputs,
    n_instruments=N_INSTRUMENTS,
    n_subbands=N_SUBBANDS,
    channels=N_CHANNELS,
    fft_dim=(STFT_N_FFT // 2) // N_SUBBANDS,
    n_scales=4,
    scale=(2, 2),
    block_size=2,
    growth=128,
    bottleneck_factor=2,
    **kwargs,
):
    """Build the TFC_TDF encoder-decoder model for source separation."""
    # Compute STFT
    x = stft(inputs)

    # Split mixture into subbands as separate channels
    mix = ops.reshape(x, (-1, x.shape[1], x.shape[2] // n_subbands, 2 * n_subbands))
    first_conv_out = layers.Conv2D(channels, 1, 1, use_bias=False)(mix)
    x = first_conv_out

    # Encoder path
    encoder_outs = []
    for _ in range(n_scales):
        x = TimeFrequencyTransformBlock(
            channels, block_size, fft_dim, bottleneck_factor
        )(x)
        encoder_outs.append(x)
        fft_dim, channels = fft_dim // scale[0], channels + growth
        x = Downscale(channels, scale)(x)

    # Bottleneck
    x = TimeFrequencyTransformBlock(channels, block_size, fft_dim, bottleneck_factor)(x)

    # Decoder path
    for _ in range(n_scales):
        fft_dim, channels = fft_dim * scale[0], channels - growth
        x = ops.concatenate([Upscale(channels, scale)(x), encoder_outs.pop()], axis=-1)
        x = TimeFrequencyTransformBlock(
            channels, block_size, fft_dim, bottleneck_factor, in_channels=x.shape[-1]
        )(x)

    # Residual connection and final convolutions
    x = ops.concatenate([mix, x * first_conv_out], axis=-1)
    x = layers.Conv2D(channels, 1, 1, use_bias=False, activation="gelu")(x)
    x = layers.Conv2D(n_instruments * n_subbands * 2, 1, 1, use_bias=False)(x)

    # Reshape back to instrument-wise STFT
    x = ops.reshape(x, (-1, x.shape[1], x.shape[2] * n_subbands, n_instruments, 2))
    x = ops.transpose(x, (0, 3, 1, 2, 4))
    x = ops.reshape(x, (-1, n_instruments, x.shape[2], x.shape[3] * 2))

    return keras.Model(inputs=inputs, outputs=x, **kwargs)


"""
## Loss and Metrics

We define:

- `spectral_loss`: Mean absolute error in the STFT domain.
- `sdr`: Signal-to-Distortion Ratio, a common source separation metric.
"""
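
"""
Concretely, the SDR implemented below is computed per target waveform as

`SDR = 10 * log10(sum(y_true^2) / sum((y_true - y_pred)^2))`

with a small epsilon added to the numerator and denominator for numerical stability.
Higher SDR means less distortion in the separated track, so we monitor it as a metric
while minimizing the spectral loss.
"""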


def prediction_to_wave(x, n_instruments=N_INSTRUMENTS):
    """Convert STFT predictions back to waveform."""
    x = ops.reshape(x, (-1, x.shape[2], x.shape[3] // 2, 2))
    x = inverse_stft(x)
    return ops.reshape(x, (-1, n_instruments, x.shape[1]))


def target_to_stft(y):
    """Convert target waveforms to their STFT representations."""
    y = ops.reshape(y, (-1, CHUNK_SIZE))
    y_real, y_imag = ops.stft(y, STFT_N_FFT, STFT_HOP_LENGTH, STFT_N_FFT)
    y_real, y_imag = y_real[..., :-1], y_imag[..., :-1]
    y = ops.stack([y_real, y_imag], axis=-1)
    return ops.reshape(y, (-1, N_INSTRUMENTS, y.shape[1], y.shape[2] * 2))


@saving.register_keras_serializable()
def sdr(y_true, y_pred):
    """Signal-to-Distortion Ratio metric."""
    y_pred = prediction_to_wave(y_pred)
    # Add epsilon for numerical stability
    num = ops.sum(ops.square(y_true), axis=-1) + 1e-8
    den = ops.sum(ops.square(y_true - y_pred), axis=-1) + 1e-8
    return 10 * ops.log10(num / den)


@saving.register_keras_serializable()
def spectral_loss(y_true, y_pred):
    """Mean absolute error in the STFT domain."""
    y_true = target_to_stft(y_true)
    return ops.mean(ops.absolute(y_true - y_pred))


"""
## Training

### Visualize Model Architecture
"""

# Load or create the model
if path.exists(MODEL_PATH):
    model = saving.load_model(MODEL_PATH)
else:
    model = build_model(keras.Input(sample_batch_x.shape[1:]), name="tfc_tdf_net")

# Display the model architecture
model.summary()
img = keras.utils.plot_model(model, path.join(TMP_DIR, "model.png"), show_shapes=True)
display.display(img)

"""
### Compile and Train the Model
"""

# Compile the model
optimizer = keras.optimizers.Adam(5e-05, gradient_accumulation_steps=ACCUMULATION_STEPS)
model.compile(optimizer=optimizer, loss=spectral_loss, metrics=[sdr])

# Define callbacks
cbs = [
    callbacks.ModelCheckpoint(MODEL_PATH, "val_sdr", save_best_only=True, mode="max"),
    callbacks.ReduceLROnPlateau(factor=0.95, patience=2),
    callbacks.CSVLogger(CSV_LOG_PATH),
]

if not path.exists(MODEL_PATH):
    model.fit(train_ds, validation_data=val_ds, epochs=10, callbacks=cbs, shuffle=False)
else:
    # Demonstrate a single epoch of training when a trained model already exists
    model.fit(train_ds, validation_data=val_ds, epochs=1, shuffle=False, verbose=2)

"""
## Evaluation

Evaluate the model on the validation dataset and visualize predicted vocals.
"""

model.evaluate(val_ds, verbose=2)
y_pred = model.predict(sample_batch_x, verbose=2)
y_pred = prediction_to_wave(y_pred)
visualize_audio_np(ops.convert_to_numpy(y_pred[0, 0]), name="vocals_pred")

"""
## Conclusion

We built and trained a vocal track separation model on the MUSDB18 dataset, using an
encoder-decoder architecture with custom TFC_TDF blocks. We demonstrated STFT-based
preprocessing, data augmentation, and a source separation metric (SDR).

**Next steps:**

- Train for more epochs and refine hyperparameters.
- Separate multiple instruments simultaneously.
- Enhance the model to handle instruments not present in the mixture.
"""
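
"""
As an optional, illustrative sketch related to the next steps above (an addition, not
part of the original example), a trained model can be applied to a full-length mono
16 kHz recording by splitting it into `CHUNK_SIZE`-sample chunks, separating each chunk,
and concatenating the estimated vocals. The helper and the file path in the commented
usage are hypothetical placeholders.
"""


def separate_vocals(model, audio: np.ndarray):
    """Separate vocals from a mono 16 kHz waveform, chunk by chunk."""
    n_chunks = int(np.ceil(len(audio) / CHUNK_SIZE))
    padded = np.pad(audio, (0, n_chunks * CHUNK_SIZE - len(audio))).astype("float32")
    chunks = padded.reshape(n_chunks, CHUNK_SIZE)
    # Predict each chunk and convert the STFT output back to waveforms
    preds = prediction_to_wave(model.predict(chunks, verbose=0))
    vocals = ops.convert_to_numpy(preds)[:, 0, :].reshape(-1)
    return vocals[: len(audio)]


# Hypothetical usage (the path below is a placeholder):
# mixture, _ = sf.read("some_song_16khz_mono.wav", dtype="float32")
# sf.write("vocals_estimate.wav", separate_vocals(model, mixture), 16000)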