"""
2
Title: Vocal Track Separation with Encoder-Decoder Architecture
3
Author: [Joaquin Jimenez](https://github.com/johacks/)
4
Date created: 2024/12/10
5
Last modified: 2024/12/10
6
Description: Train a model to separate vocal tracks from music mixtures.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
In this tutorial, we build a vocal track separation model using an encoder-decoder
14
architecture in Keras 3.
15
16
We train the model on the [MUSDB18 dataset](https://doi.org/10.5281/zenodo.1117372),
17
which provides music mixtures and isolated tracks for drums, bass, other, and vocals.
18
19
Key concepts covered:
20
21
- Audio data preprocessing using the Short-Time Fourier Transform (STFT).
22
- Audio data augmentation techniques.
23
- Implementing custom encoders and decoders specialized for audio data.
24
- Defining appropriate loss functions and metrics for audio source separation tasks.
25
26
The model architecture is derived from the TFC_TDF_Net model described in:
27
28
W. Choi, M. Kim, J. Chung, D. Lee, and S. Jung, “Investigating U-Nets with various
29
intermediate blocks for spectrogram-based singing voice separation,” in the 21st
30
International Society for Music Information Retrieval Conference, 2020.
31
32
For reference code, see:
33
[GitHub: ws-choi/ISMIR2020_U_Nets_SVS](https://github.com/ws-choi/ISMIR2020_U_Nets_SVS).
34
35
The data processing and model training routines are partly derived from:
36
[ZFTurbo/Music-Source-Separation-Training](https://github.com/ZFTurbo/Music-Source-Separation-Training/tree/main).
37
"""
38
39
"""
40
## Setup
41
42
Import and install all the required dependencies.
43
"""
44
45
"""shell
46
pip install -qq audiomentations soundfile ffmpeg-binaries
47
pip install -qq "keras==3.7.0"
48
sudo -n apt-get install -y graphviz >/dev/null 2>&1 # Required for plotting the model
49
"""
50
51
import glob
import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import random
import subprocess
import tempfile
import typing
from os import path

import audiomentations as aug
import ffmpeg
import keras
import numpy as np
import soundfile as sf
from IPython import display
from keras import callbacks, layers, ops, saving
from matplotlib import pyplot as plt

"""
## Configuration

The following constants define configuration parameters for audio processing
and model training, including dataset paths, audio chunk sizes, Short-Time Fourier
Transform (STFT) parameters, and training hyperparameters.
"""

# MUSDB18 dataset configuration
MUSDB_STREAMS = {"mixture": 0, "drums": 1, "bass": 2, "other": 3, "vocals": 4}
TARGET_INSTRUMENTS = {track: MUSDB_STREAMS[track] for track in ("vocals",)}
N_INSTRUMENTS = len(TARGET_INSTRUMENTS)
SOURCE_INSTRUMENTS = tuple(k for k in MUSDB_STREAMS if k != "mixture")

# Audio preprocessing parameters for Short-Time Fourier Transform (STFT)
N_SUBBANDS = 4  # Number of subbands into which frequencies are split
CHUNK_SIZE = 65024  # Number of amplitude samples per audio chunk (~4 seconds)
STFT_N_FFT = 2048  # FFT points used in STFT
STFT_HOP_LENGTH = 512  # Hop length for STFT

# Training hyperparameters
N_CHANNELS = 64  # Base channel count for the model
BATCH_SIZE = 3
ACCUMULATION_STEPS = 2
EFFECTIVE_BATCH_SIZE = BATCH_SIZE * (ACCUMULATION_STEPS or 1)

# Paths
TMP_DIR = path.expanduser("~/.keras/tmp")
DATASET_DIR = path.expanduser("~/.keras/datasets")
MODEL_PATH = path.join(TMP_DIR, f"model_{keras.backend.backend()}.keras")
CSV_LOG_PATH = path.join(TMP_DIR, f"training_{keras.backend.backend()}.csv")
os.makedirs(DATASET_DIR, exist_ok=True)
os.makedirs(TMP_DIR, exist_ok=True)

# Set random seed for reproducibility
keras.utils.set_random_seed(21)

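"""
As a quick sanity check of these constants (plain arithmetic only, assuming the
16 kHz sample rate that the audio is resampled to during dataset conversion below):
a 65,024-sample chunk spans roughly 4.1 seconds, and the `STFT_N_FFT // 2 = 1024`
frequency bins retained by the preprocessing step defined later split evenly into
256 bins per subband.
"""

print(f"Chunk duration: {CHUNK_SIZE / 16000:.2f} seconds")  # 16 kHz is assumed here
print(f"Frequency bins kept per STFT frame: {STFT_N_FFT // 2}")
print(f"Frequency bins per subband: {(STFT_N_FFT // 2) // N_SUBBANDS}")
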
"""
109
## MUSDB18 Dataset
110
111
The MUSDB18 dataset is a standard benchmark for music source separation, containing
112
150 full-length music tracks along with isolated drums, bass, other, and vocals.
113
The dataset is stored in .mp4 format, and each .mp4 file includes multiple audio
114
streams (mixture and individual tracks).
115
116
### Download and Conversion
117
118
The following utility function downloads MUSDB18 and converts its .mp4 files to
119
.wav files for each instrument track, resampled to 16 kHz.
120
"""
121
122
123
def download_musdb18(out_dir=None):
    """Download and extract the MUSDB18 dataset, then convert .mp4 files to .wav files.

    MUSDB18 reference:
    Rafii, Z., Liutkus, A., Stöter, F.-R., Mimilakis, S. I., & Bittner, R. (2017).
    MUSDB18 - a corpus for music separation (1.0.0) [Data set]. Zenodo.
    """
    ffmpeg.init()
    from ffmpeg import FFMPEG_PATH

    # Create output directories
    os.makedirs((base := out_dir or tempfile.mkdtemp()), exist_ok=True)
    if path.exists((out_dir := path.join(base, "musdb18_wav"))):
        print("MUSDB18 dataset already downloaded")
        return out_dir

    # Download and extract the dataset
    download_dir = keras.utils.get_file(
        fname="musdb18",
        origin="https://zenodo.org/records/1117372/files/musdb18.zip",
        extract=True,
    )

    # ffmpeg command template: input, stream index, output
    ffmpeg_args = str(FFMPEG_PATH) + " -v error -i {} -map 0:{} -vn -ar 16000 {}"

    # Convert each mp4 file to multiple .wav files for each track
    for split in ("train", "test"):
        songs = os.listdir(path.join(download_dir, split))
        for i, song in enumerate(songs):
            if i % 10 == 0:
                print(f"{split.capitalize()}: {i}/{len(songs)} songs processed")

            mp4_path_orig = path.join(download_dir, split, song)
            mp4_path = path.join(tempfile.mkdtemp(), split, song.replace(" ", "_"))
            os.makedirs(path.dirname(mp4_path), exist_ok=True)
            os.rename(mp4_path_orig, mp4_path)

            wav_dir = path.join(out_dir, split, path.basename(mp4_path).split(".")[0])
            os.makedirs(wav_dir, exist_ok=True)

            for track in SOURCE_INSTRUMENTS:
                out_path = path.join(wav_dir, f"{track}.wav")
                stream_index = MUSDB_STREAMS[track]
                args = ffmpeg_args.format(mp4_path, stream_index, out_path).split()
                assert subprocess.run(args).returncode == 0, "ffmpeg conversion failed"
    return out_dir


# Download and prepare the MUSDB18 dataset
songs = download_musdb18(out_dir=DATASET_DIR)

"""
176
### Custom Dataset
177
178
We define a custom dataset class to generate random audio chunks and their corresponding
179
labels. The dataset does the following:
180
181
1. Selects a random chunk from a random song and instrument.
182
2. Applies optional data augmentations.
183
3. Combines isolated tracks to form new synthetic mixtures.
184
4. Prepares features (mixtures) and labels (vocals) for training.
185
186
This approach allows creating an effectively infinite variety of training examples
187
through randomization and augmentation.
188
"""
189
190
191
class Dataset(keras.utils.PyDataset):
    def __init__(
        self,
        songs,
        batch_size=BATCH_SIZE,
        chunk_size=CHUNK_SIZE,
        batches_per_epoch=1000 * ACCUMULATION_STEPS,
        augmentation=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.augmentation = augmentation
        self.vocals_augmentations = [
            aug.PitchShift(min_semitones=-5, max_semitones=5, p=0.1),
            aug.SevenBandParametricEQ(-9, 9, p=0.25),
            aug.TanhDistortion(0.1, 0.7, p=0.1),
        ]
        self.other_augmentations = [
            aug.PitchShift(p=0.1),
            aug.AddGaussianNoise(p=0.1),
        ]
        self.songs = songs
        self.sizes = {song: self.get_track_set_size(song) for song in self.songs}
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.batches_per_epoch = batches_per_epoch

    def get_track_set_size(self, song: str):
        """Return the smallest track length in the given song directory."""
        sizes = [len(sf.read(p)[0]) for p in glob.glob(path.join(song, "*.wav"))]
        if max(sizes) != min(sizes):
            print(f"Warning: {song} has different track lengths")
        return min(sizes)

    def random_chunk_of_instrument_type(self, instrument: str):
        """Extract a random chunk for the specified instrument from a random song."""
        song, size = random.choice(list(self.sizes.items()))
        track = path.join(song, f"{instrument}.wav")

        if self.chunk_size <= size:
            start = np.random.randint(size - self.chunk_size + 1)
            audio = sf.read(track, self.chunk_size, start, dtype="float32")[0]
            audio_mono = np.mean(audio, axis=1)
        else:
            # If the track is shorter than chunk_size, pad the signal
            audio_mono = np.mean(sf.read(track, dtype="float32")[0], axis=1)
            audio_mono = np.pad(audio_mono, ((0, self.chunk_size - size),))

        # If the chunk is almost silent, retry
        if np.mean(np.abs(audio_mono)) < 0.01:
            return self.random_chunk_of_instrument_type(instrument)

        return self.data_augmentation(audio_mono, instrument)

    def data_augmentation(self, audio: np.ndarray, instrument: str):
        """Apply data augmentation to the audio chunk, if enabled."""

        def coin_flip(x, probability: float, fn: typing.Callable):
            return fn(x) if random.uniform(0, 1) < probability else x

        if self.augmentation:
            augmentations = (
                self.vocals_augmentations
                if instrument == "vocals"
                else self.other_augmentations
            )
            # Loudness augmentation
            audio *= np.random.uniform(0.5, 1.5, (len(audio),)).astype("float32")
            # Random reverse
            audio = coin_flip(audio, 0.1, lambda x: np.flip(x))
            # Random polarity inversion
            audio = coin_flip(audio, 0.5, lambda x: -x)
            # Apply selected augmentations
            for aug_ in augmentations:
                aug_.randomize_parameters(audio, sample_rate=16000)
                audio = aug_(audio, sample_rate=16000)
        return audio

    def random_mix_of_tracks(self) -> dict:
        """Create a random mix of instruments by summing their individual chunks."""
        tracks = {}
        for instrument in SOURCE_INSTRUMENTS:
            # Start with a single random chunk
            mixup = [self.random_chunk_of_instrument_type(instrument)]

            # Randomly add more chunks of the same instrument (mixup augmentation)
            if self.augmentation:
                for p in (0.2, 0.02):
                    if random.uniform(0, 1) < p:
                        mixup.append(self.random_chunk_of_instrument_type(instrument))

            tracks[instrument] = np.mean(mixup, axis=0, dtype="float32")
        return tracks

    def __len__(self):
        return self.batches_per_epoch

    def __getitem__(self, idx):
        # Generate a batch of random mixtures
        batch = [self.random_mix_of_tracks() for _ in range(self.batch_size)]

        # Features: sum of all tracks
        batch_x = ops.sum(
            np.array([list(track_set.values()) for track_set in batch]), axis=1
        )

        # Labels: isolated target instruments (e.g., vocals)
        batch_y = np.array(
            [[track_set[t] for t in TARGET_INSTRUMENTS] for track_set in batch]
        )

        return batch_x, ops.convert_to_tensor(batch_y)


# Create train and validation datasets
train_ds = Dataset(glob.glob(path.join(songs, "train", "*")))
val_ds = Dataset(
    glob.glob(path.join(songs, "test", "*")),
    batches_per_epoch=int(0.1 * train_ds.batches_per_epoch),
    augmentation=False,
)

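"""
Before visualizing, it is worth confirming the tensor shapes the dataset yields (a
minimal check, assuming the dataset downloaded and converted successfully): features
are mixtures of shape `(BATCH_SIZE, CHUNK_SIZE)` and labels are isolated target
tracks of shape `(BATCH_SIZE, N_INSTRUMENTS, CHUNK_SIZE)`.
"""

check_x, check_y = train_ds[0]  # Any index works; batches are generated randomly
print("Features (mixture) shape:", check_x.shape)
print("Labels (vocals) shape:", check_y.shape)
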
"""
314
### Visualize a Sample
315
316
Let's visualize a random mixed audio chunk and its corresponding isolated vocals.
317
This helps to understand the nature of the preprocessed input data.
318
"""
319
320
321
def visualize_audio_np(audio: np.ndarray, rate=16000, name="mixup"):
322
"""Plot and display an audio waveform and also produce an Audio widget."""
323
plt.figure(figsize=(10, 6))
324
plt.plot(audio)
325
plt.title(f"Waveform: {name}")
326
plt.xlim(0, len(audio))
327
plt.ylabel("Amplitude")
328
plt.show()
329
# plt.savefig(f"tmp/{name}.png")
330
331
# Normalize and display audio
332
audio_norm = (audio - np.min(audio)) / (np.max(audio) - np.min(audio) + 1e-8)
333
audio_norm = (audio_norm * 2 - 1) * 0.6
334
display.display(display.Audio(audio_norm, rate=rate))
335
# sf.write(f"tmp/{name}.wav", audio_norm, rate)
336
337
338
sample_batch_x, sample_batch_y = val_ds[None] # Random batch
339
visualize_audio_np(ops.convert_to_numpy(sample_batch_x[0]))
340
visualize_audio_np(ops.convert_to_numpy(sample_batch_y[0, 0]), name="vocals")
341
342
"""
343
## Model
344
345
### Preprocessing
346
347
The model operates on STFT representations rather than raw audio. We define a
348
preprocessing model to compute STFT and a corresponding inverse transform (iSTFT).
349
"""
350
351
352
def stft(inputs, fft_size=STFT_N_FFT, sequence_stride=STFT_HOP_LENGTH):
353
"""Compute the STFT for the input audio and return the real and imaginary parts."""
354
real_x, imag_x = ops.stft(inputs, fft_size, sequence_stride, fft_size)
355
real_x, imag_x = ops.expand_dims(real_x, -1), ops.expand_dims(imag_x, -1)
356
x = ops.concatenate((real_x, imag_x), axis=-1)
357
358
# Drop last freq sample for convenience
359
return ops.split(x, [x.shape[2] - 1], axis=2)[0]
360
361
362
def inverse_stft(inputs, fft_size=STFT_N_FFT, sequence_stride=STFT_HOP_LENGTH):
363
"""Compute the inverse STFT for the given STFT input."""
364
x = inputs
365
366
# Pad back dropped freq sample if using torch backend
367
if keras.backend.backend() == "torch":
368
x = ops.pad(x, ((0, 0), (0, 0), (0, 1), (0, 0)))
369
370
real_x, imag_x = ops.split(x, 2, axis=-1)
371
real_x = ops.squeeze(real_x, axis=-1)
372
imag_x = ops.squeeze(imag_x, axis=-1)
373
374
return ops.istft((real_x, imag_x), fft_size, sequence_stride, fft_size)
375
376
377
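"""
As a quick shape check of these helpers (a sketch on dummy data; exact frame counts
depend on the backend's STFT padding behavior), we transform one chunk of silence and
invert it, printing the spectrogram shape and the reconstructed waveform shape.
"""

dummy_audio = ops.zeros((1, CHUNK_SIZE))
dummy_spec = stft(dummy_audio)
print("STFT shape (batch, frames, frequency bins, real/imag):", dummy_spec.shape)
print("Inverse STFT shape (batch, samples):", inverse_stft(dummy_spec).shape)
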
"""
378
### Model Architecture
379
380
The model uses a custom encoder-decoder architecture with Time-Frequency Convolution
381
(TFC) and Time-Distributed Fully Connected (TDF) blocks. They are grouped into a
382
`TimeFrequencyTransformBlock`, i.e. "TFC_TDF" in the original paper by Choi et al.
383
384
We then define an encoder-decoder network with multiple scales. Each encoder scale
385
applies TFC_TDF blocks followed by downsampling, while decoder scales apply TFC_TDF
386
blocks over the concatenation of upsampled features and associated encoder outputs.
387
"""
388
389
390
@saving.register_keras_serializable()
class TimeDistributedDenseBlock(layers.Layer):
    """Time-Distributed Fully Connected layer block.

    Applies frequency-wise dense transformations across time frames with instance
    normalization and GELU activation.
    """

    def __init__(self, bottleneck_factor, fft_dim, **kwargs):
        super().__init__(**kwargs)
        self.fft_dim = fft_dim
        self.hidden_dim = fft_dim // bottleneck_factor

    def build(self, *_):
        self.group_norm_1 = layers.GroupNormalization(groups=-1)
        self.group_norm_2 = layers.GroupNormalization(groups=-1)
        self.dense_1 = layers.Dense(self.hidden_dim, use_bias=False)
        self.dense_2 = layers.Dense(self.fft_dim, use_bias=False)

    def call(self, x):
        # Apply normalization and dense layers frequency-wise
        x = ops.gelu(self.group_norm_1(x))
        x = ops.swapaxes(x, -1, -2)
        x = self.dense_1(x)

        x = ops.gelu(self.group_norm_2(ops.swapaxes(x, -1, -2)))
        x = ops.swapaxes(x, -1, -2)
        x = self.dense_2(x)
        return ops.swapaxes(x, -1, -2)


@saving.register_keras_serializable()
class TimeFrequencyConvolution(layers.Layer):
    """Time-Frequency Convolutional layer.

    Applies instance normalization and GELU activation followed by a 2D convolution
    over the time-frequency representation.
    """

    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels

    def build(self, *_):
        self.group_norm = layers.GroupNormalization(groups=-1)
        self.conv = layers.Conv2D(self.channels, 3, padding="same", use_bias=False)

    def call(self, x):
        return self.conv(ops.gelu(self.group_norm(x)))


@saving.register_keras_serializable()
class TimeFrequencyTransformBlock(layers.Layer):
    """Implements the TFC_TDF block for the encoder-decoder architecture.

    Applies Time-Frequency Convolution and Time-Distributed Dense blocks as many
    times as specified by the `length` parameter.
    """

    def __init__(
        self, channels, length, fft_dim, bottleneck_factor, in_channels=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.length = length
        self.fft_dim = fft_dim
        self.bottleneck_factor = bottleneck_factor
        self.in_channels = in_channels or channels

    def build(self, *_):
        self.blocks = []
        # Add blocks in a flat list to avoid nested structures
        for i in range(self.length):
            in_channels = self.channels if i > 0 else self.in_channels
            self.blocks.append(TimeFrequencyConvolution(in_channels))
            self.blocks.append(
                TimeDistributedDenseBlock(self.bottleneck_factor, self.fft_dim)
            )
            self.blocks.append(TimeFrequencyConvolution(self.channels))
            # Residual connection
            self.blocks.append(layers.Conv2D(self.channels, 1, 1, use_bias=False))

    def call(self, inputs):
        x = inputs
        # Each block consists of 4 layers:
        # 1. Time-Frequency Convolution
        # 2. Time-Distributed Dense
        # 3. Time-Frequency Convolution
        # 4. Residual connection
        for i in range(0, len(self.blocks), 4):
            tfc_1 = self.blocks[i](x)
            tdf = self.blocks[i + 1](x)
            tfc_2 = self.blocks[i + 2](tfc_1 + tdf)
            x = tfc_2 + self.blocks[i + 3](x)  # Residual connection
        return x


@saving.register_keras_serializable()
class Downscale(layers.Layer):
    """Downscale time-frequency dimensions using a convolution."""

    conv_cls = layers.Conv2D

    def __init__(self, channels, scale, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.scale = scale

    def build(self, *_):
        self.conv = self.conv_cls(self.channels, self.scale, self.scale, use_bias=False)
        self.norm = layers.GroupNormalization(groups=-1)

    def call(self, inputs):
        return self.norm(ops.gelu(self.conv(inputs)))


@saving.register_keras_serializable()
class Upscale(Downscale):
    """Upscale time-frequency dimensions using a transposed convolution."""

    conv_cls = layers.Conv2DTranspose


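"""
Before assembling the full network, here is a quick sanity check of the custom blocks
on a dummy spectrogram-like tensor (a sketch; the dimensions and channel counts used
here are illustrative only): a `TimeFrequencyTransformBlock` preserves the time and
frequency dimensions while mapping its input to `channels` output channels, and
`Downscale` with `scale=(2, 2)` halves both the time and frequency dimensions.
"""

dummy_block_in = ops.zeros((1, 128, 256, 8))  # (batch, time, frequency, channels)
dummy_block_out = TimeFrequencyTransformBlock(
    N_CHANNELS, length=2, fft_dim=256, bottleneck_factor=2, in_channels=8
)(dummy_block_in)
print("TimeFrequencyTransformBlock output shape:", dummy_block_out.shape)
dummy_down_out = Downscale(N_CHANNELS, scale=(2, 2))(dummy_block_in)
print("Downscale output shape:", dummy_down_out.shape)

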
def build_model(
    inputs,
    n_instruments=N_INSTRUMENTS,
    n_subbands=N_SUBBANDS,
    channels=N_CHANNELS,
    fft_dim=(STFT_N_FFT // 2) // N_SUBBANDS,
    n_scales=4,
    scale=(2, 2),
    block_size=2,
    growth=128,
    bottleneck_factor=2,
    **kwargs,
):
    """Build the TFC_TDF encoder-decoder model for source separation."""
    # Compute STFT
    x = stft(inputs)

    # Split mixture into subbands as separate channels
    mix = ops.reshape(x, (-1, x.shape[1], x.shape[2] // n_subbands, 2 * n_subbands))
    first_conv_out = layers.Conv2D(channels, 1, 1, use_bias=False)(mix)
    x = first_conv_out

    # Encoder path
    encoder_outs = []
    for _ in range(n_scales):
        x = TimeFrequencyTransformBlock(
            channels, block_size, fft_dim, bottleneck_factor
        )(x)
        encoder_outs.append(x)
        fft_dim, channels = fft_dim // scale[0], channels + growth
        x = Downscale(channels, scale)(x)

    # Bottleneck
    x = TimeFrequencyTransformBlock(channels, block_size, fft_dim, bottleneck_factor)(x)

    # Decoder path
    for _ in range(n_scales):
        fft_dim, channels = fft_dim * scale[0], channels - growth
        x = ops.concatenate([Upscale(channels, scale)(x), encoder_outs.pop()], axis=-1)
        x = TimeFrequencyTransformBlock(
            channels, block_size, fft_dim, bottleneck_factor, in_channels=x.shape[-1]
        )(x)

    # Residual connection and final convolutions
    x = ops.concatenate([mix, x * first_conv_out], axis=-1)
    x = layers.Conv2D(channels, 1, 1, use_bias=False, activation="gelu")(x)
    x = layers.Conv2D(n_instruments * n_subbands * 2, 1, 1, use_bias=False)(x)

    # Reshape back to instrument-wise STFT
    x = ops.reshape(x, (-1, x.shape[1], x.shape[2] * n_subbands, n_instruments, 2))
    x = ops.transpose(x, (0, 3, 1, 2, 4))
    x = ops.reshape(x, (-1, n_instruments, x.shape[2], x.shape[3] * 2))

    return keras.Model(inputs=inputs, outputs=x, **kwargs)


"""
570
## Loss and Metrics
571
572
We define:
573
574
- `spectral_loss`: Mean absolute error in STFT domain.
575
- `sdr`: Signal-to-Distortion Ratio, a common source separation metric.
576
"""
577
578
579
def prediction_to_wave(x, n_instruments=N_INSTRUMENTS):
580
"""Convert STFT predictions back to waveform."""
581
x = ops.reshape(x, (-1, x.shape[2], x.shape[3] // 2, 2))
582
x = inverse_stft(x)
583
return ops.reshape(x, (-1, n_instruments, x.shape[1]))
584
585
586
def target_to_stft(y):
587
"""Convert target waveforms to their STFT representations."""
588
y = ops.reshape(y, (-1, CHUNK_SIZE))
589
y_real, y_imag = ops.stft(y, STFT_N_FFT, STFT_HOP_LENGTH, STFT_N_FFT)
590
y_real, y_imag = y_real[..., :-1], y_imag[..., :-1]
591
y = ops.stack([y_real, y_imag], axis=-1)
592
return ops.reshape(y, (-1, N_INSTRUMENTS, y.shape[1], y.shape[2] * 2))
593
594
595
@saving.register_keras_serializable()
596
def sdr(y_true, y_pred):
597
"""Signal-to-Distortion Ratio metric."""
598
y_pred = prediction_to_wave(y_pred)
599
# Add epsilon for numerical stability
600
num = ops.sum(ops.square(y_true), axis=-1) + 1e-8
601
den = ops.sum(ops.square(y_true - y_pred), axis=-1) + 1e-8
602
return 10 * ops.log10(num / den)
603
604
605
@saving.register_keras_serializable()
606
def spectral_loss(y_true, y_pred):
607
"""Mean absolute error in the STFT domain."""
608
y_true = target_to_stft(y_true)
609
return ops.mean(ops.absolute(y_true - y_pred))
610
611
612
"""
613
## Training
614
615
### Visualize Model Architecture
616
"""
617
618
# Load or create the model
619
if path.exists(MODEL_PATH):
620
model = saving.load_model(MODEL_PATH)
621
else:
622
model = build_model(keras.Input(sample_batch_x.shape[1:]), name="tfc_tdf_net")
623
624
# Display the model architecture
625
model.summary()
626
img = keras.utils.plot_model(model, path.join(TMP_DIR, "model.png"), show_shapes=True)
627
display.display(img)
628
629
"""
630
### Compile and Train the Model
631
"""
632
633
# Compile the model
634
optimizer = keras.optimizers.Adam(5e-05, gradient_accumulation_steps=ACCUMULATION_STEPS)
635
model.compile(optimizer=optimizer, loss=spectral_loss, metrics=[sdr])
636
637
# Define callbacks
638
cbs = [
639
callbacks.ModelCheckpoint(MODEL_PATH, "val_sdr", save_best_only=True, mode="max"),
640
callbacks.ReduceLROnPlateau(factor=0.95, patience=2),
641
callbacks.CSVLogger(CSV_LOG_PATH),
642
]
643
644
if not path.exists(MODEL_PATH):
645
model.fit(train_ds, validation_data=val_ds, epochs=10, callbacks=cbs, shuffle=False)
646
else:
647
# Demonstration of a single epoch of training when model already exists
648
model.fit(train_ds, validation_data=val_ds, epochs=1, shuffle=False, verbose=2)
649
650
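"""
Training curves can be inspected from the CSV file written by `CSVLogger` (a small
sketch; the exact column names depend on the compiled loss and metrics, so we simply
plot every logged column against the epoch number, if the log exists).
"""

if path.exists(CSV_LOG_PATH):
    # Read the CSV log as a structured array with one field per column
    log = np.genfromtxt(CSV_LOG_PATH, delimiter=",", names=True)
    for column in log.dtype.names:
        if column != "epoch":
            plt.plot(log["epoch"], log[column], label=column)
    plt.xlabel("Epoch")
    plt.legend()
    plt.show()
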
"""
651
## Evaluation
652
653
Evaluate the model on the validation dataset and visualize predicted vocals.
654
"""
655
656
model.evaluate(val_ds, verbose=2)
657
y_pred = model.predict(sample_batch_x, verbose=2)
658
y_pred = prediction_to_wave(y_pred)
659
visualize_audio_np(ops.convert_to_numpy(y_pred[0, 0]), name="vocals_pred")
660
661
"""
662
## Conclusion
663
664
We built and trained a vocal track separation model using an encoder-decoder
665
architecture with custom blocks applied to the MUSDB18 dataset. We demonstrated
666
STFT-based preprocessing, data augmentation, and a source separation metric (SDR).
667
668
**Next steps:**
669
670
- Train for more epochs and refine hyperparameters.
671
- Separate multiple instruments simultaneously.
672
- Enhance the model to handle instruments not present in the mixture.
673
"""