GitHub Repository: keras-team/keras-io
Path: blob/master/examples/audio/speaker_recognition_using_cnn.py
"""
2
Title: Speaker Recognition
3
Author: [Fadi Badine](https://twitter.com/fadibadine)
4
Date created: 14/06/2020
5
Last modified: 19/07/2023
6
Description: Classify speakers using Fast Fourier Transform (FFT) and a 1D Convnet.
7
Accelerator: GPU
8
Converted to Keras 3 by: [Fadi Badine](https://twitter.com/fadibadine)
9
"""
10
11
"""
12
## Introduction
13
14
This example demonstrates how to create a model to classify speakers from the
15
frequency domain representation of speech recordings, obtained via Fast Fourier
16
Transform (FFT).
17
18
It shows the following:
19
20
- How to use `tf.data` to load, preprocess and feed audio streams into a model
21
- How to create a 1D convolutional network with residual
22
connections for audio classification.
23
24
Our process:
25
26
- We prepare a dataset of speech samples from different speakers, with the speaker as label.
27
- We add background noise to these samples to augment our data.
28
- We take the FFT of these samples.
29
- We train a 1D convnet to predict the correct speaker given a noisy FFT speech sample.
30
31
Note:
32
33
- This example should be run with TensorFlow 2.3 or higher, or `tf-nightly`.
34
- The noise samples in the dataset need to be resampled to a sampling rate of 16000 Hz
35
before using the code in this example. In order to do this, you will need to have
36
installed `ffmpg`.
37
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import shutil
import numpy as np

import tensorflow as tf
import keras

from pathlib import Path
from IPython.display import display, Audio

# Get the data from https://www.kaggle.com/kongaevans/speaker-recognition-dataset/
# and save it to ./speaker-recognition-dataset.zip
# then unzip it to ./16000_pcm_speeches
"""shell
kaggle datasets download -d kongaevans/speaker-recognition-dataset
unzip -qq speaker-recognition-dataset.zip
"""

DATASET_ROOT = "16000_pcm_speeches"
65
66
# The folders in which we will put the audio samples and the noise samples
67
AUDIO_SUBFOLDER = "audio"
68
NOISE_SUBFOLDER = "noise"
69
70
DATASET_AUDIO_PATH = os.path.join(DATASET_ROOT, AUDIO_SUBFOLDER)
71
DATASET_NOISE_PATH = os.path.join(DATASET_ROOT, NOISE_SUBFOLDER)
72
73
# Percentage of samples to use for validation
74
VALID_SPLIT = 0.1
75
76
# Seed to use when shuffling the dataset and the noise
77
SHUFFLE_SEED = 43
78
79
# The sampling rate to use.
80
# This is the one used in all the audio samples.
81
# We will resample all the noise to this sampling rate.
82
# This will also be the output size of the audio wave samples
83
# (since all samples are of 1 second long)
84
SAMPLING_RATE = 16000
85
86
# The factor to multiply the noise with according to:
87
# noisy_sample = sample + noise * prop * scale
88
# where prop = sample_amplitude / noise_amplitude
89
SCALE = 0.5
90
91
BATCH_SIZE = 128
92
EPOCHS = 1 # For a real training run, use EPOCHS = 100
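
"""
To make the `SCALE` formula above concrete, here is a minimal NumPy sketch (with
hypothetical arrays, not part of the pipeline): the noise is first rescaled so that
its peak matches the speech sample's peak, then attenuated by `scale` before being
added to the sample.

```python
import numpy as np

sample = np.array([0.0, 0.8, -0.4])  # hypothetical 3-step speech waveform
noise = np.array([0.1, -0.2, 0.05])  # hypothetical 3-step noise waveform

prop = sample.max() / noise.max()  # amplitude ratio: 0.8 / 0.1 = 8.0
noisy_sample = sample + noise * prop * 0.5  # with scale = 0.5
# noisy_sample -> [0.4, 0.0, -0.2]
```
"""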

"""
## Data preparation

The dataset is composed of 7 folders, divided into 2 groups:

- Speech samples, with 5 folders for 5 different speakers. Each folder contains
1500 audio files, each 1 second long and sampled at 16000 Hz.
- Background noise samples, with 2 folders and a total of 6 files. These files
are longer than 1 second (and originally not sampled at 16000 Hz, but we will resample them to 16000 Hz).
We will use those 6 files to create 354 1-second-long noise samples to be used for training.

Let's sort these 2 categories into 2 folders:

- An `audio` folder which will contain all the per-speaker speech sample folders
- A `noise` folder which will contain all the noise samples
"""

"""
Before sorting the audio and noise categories into 2 folders,
we have the following directory structure:

```
main_directory/
...speaker_a/
...speaker_b/
...speaker_c/
...speaker_d/
...speaker_e/
...other/
..._background_noise_/
```

After sorting, we end up with the following structure:

```
main_directory/
...audio/
......speaker_a/
......speaker_b/
......speaker_c/
......speaker_d/
......speaker_e/
...noise/
......other/
......_background_noise_/
```
"""
142
143
for folder in os.listdir(DATASET_ROOT):
144
if os.path.isdir(os.path.join(DATASET_ROOT, folder)):
145
if folder in [AUDIO_SUBFOLDER, NOISE_SUBFOLDER]:
146
# If folder is `audio` or `noise`, do nothing
147
continue
148
elif folder in ["other", "_background_noise_"]:
149
# If folder is one of the folders that contains noise samples,
150
# move it to the `noise` folder
151
shutil.move(
152
os.path.join(DATASET_ROOT, folder),
153
os.path.join(DATASET_NOISE_PATH, folder),
154
)
155
else:
156
# Otherwise, it should be a speaker folder, then move it to
157
# `audio` folder
158
shutil.move(
159
os.path.join(DATASET_ROOT, folder),
160
os.path.join(DATASET_AUDIO_PATH, folder),
161
)
162
163
"""
164
## Noise preparation
165
166
In this section:
167
168
- We load all noise samples (which should have been resampled to 16000)
169
- We split those noise samples to chunks of 16000 samples which
170
correspond to 1 second duration each
171
"""
172
173
# Get the list of all noise files
174
noise_paths = []
175
for subdir in os.listdir(DATASET_NOISE_PATH):
176
subdir_path = Path(DATASET_NOISE_PATH) / subdir
177
if os.path.isdir(subdir_path):
178
noise_paths += [
179
os.path.join(subdir_path, filepath)
180
for filepath in os.listdir(subdir_path)
181
if filepath.endswith(".wav")
182
]
183
if not noise_paths:
184
raise RuntimeError(f"Could not find any files at {DATASET_NOISE_PATH}")
185
print(
186
"Found {} files belonging to {} directories".format(
187
len(noise_paths), len(os.listdir(DATASET_NOISE_PATH))
188
)
189
)
190
191
"""
192
Resample all noise samples to 16000 Hz
193
"""
194
195
command = (
196
"for dir in `ls -1 " + DATASET_NOISE_PATH + "`; do "
197
"for file in `ls -1 " + DATASET_NOISE_PATH + "/$dir/*.wav`; do "
198
"sample_rate=`ffprobe -hide_banner -loglevel panic -show_streams "
199
"$file | grep sample_rate | cut -f2 -d=`; "
200
"if [ $sample_rate -ne 16000 ]; then "
201
"ffmpeg -hide_banner -loglevel panic -y "
202
"-i $file -ar 16000 temp.wav; "
203
"mv temp.wav $file; "
204
"fi; done; done"
205
)
206
os.system(command)
207
208
209


# Split noise into chunks of 16,000 steps each
def load_noise_sample(path):
    sample, sampling_rate = tf.audio.decode_wav(
        tf.io.read_file(path), desired_channels=1
    )
    if sampling_rate == SAMPLING_RATE:
        # Number of slices of 16000 each that can be generated from the noise sample
        slices = int(sample.shape[0] / SAMPLING_RATE)
        sample = tf.split(sample[: slices * SAMPLING_RATE], slices)
        return sample
    else:
        print("Sampling rate for {} is incorrect. Ignoring it".format(path))
        return None


noises = []
for path in noise_paths:
    sample = load_noise_sample(path)
    if sample:
        noises.extend(sample)
noises = tf.stack(noises)

print(
    "{} noise files were split into {} noise samples where each is {} sec. long".format(
        len(noise_paths), noises.shape[0], noises.shape[1] // SAMPLING_RATE
    )
)
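
"""
At this point `noises` is a single tensor of shape `(num_chunks, SAMPLING_RATE, 1)`;
with the 6 provided noise files this yields the 354 one-second chunks mentioned above.
A quick check (a sketch, assuming the dataset was downloaded and resampled as described):

```python
print(noises.shape)  # expected: (354, 16000, 1)
```
"""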
"""
238
## Dataset generation
239
"""
240
241
242
def paths_and_labels_to_dataset(audio_paths, labels):
243
"""Constructs a dataset of audios and labels."""
244
path_ds = tf.data.Dataset.from_tensor_slices(audio_paths)
245
audio_ds = path_ds.map(
246
lambda x: path_to_audio(x), num_parallel_calls=tf.data.AUTOTUNE
247
)
248
label_ds = tf.data.Dataset.from_tensor_slices(labels)
249
return tf.data.Dataset.zip((audio_ds, label_ds))
250
251
252
def path_to_audio(path):
253
"""Reads and decodes an audio file."""
254
audio = tf.io.read_file(path)
255
audio, _ = tf.audio.decode_wav(audio, 1, SAMPLING_RATE)
256
return audio
257
258
259
def add_noise(audio, noises=None, scale=0.5):
260
if noises is not None:
261
# Create a random tensor of the same size as audio ranging from
262
# 0 to the number of noise stream samples that we have.
263
tf_rnd = tf.random.uniform(
264
(tf.shape(audio)[0],), 0, noises.shape[0], dtype=tf.int32
265
)
266
noise = tf.gather(noises, tf_rnd, axis=0)
267
268
# Get the amplitude proportion between the audio and the noise
269
prop = tf.math.reduce_max(audio, axis=1) / tf.math.reduce_max(noise, axis=1)
270
prop = tf.repeat(tf.expand_dims(prop, axis=1), tf.shape(audio)[1], axis=1)
271
272
# Adding the rescaled noise to audio
273
audio = audio + noise * prop * scale
274
275
return audio
276
277
278
def audio_to_fft(audio):
279
# Since tf.signal.fft applies FFT on the innermost dimension,
280
# we need to squeeze the dimensions and then expand them again
281
# after FFT
282
audio = tf.squeeze(audio, axis=-1)
283
fft = tf.signal.fft(
284
tf.cast(tf.complex(real=audio, imag=tf.zeros_like(audio)), tf.complex64)
285
)
286
fft = tf.expand_dims(fft, axis=-1)
287
288
# Return the absolute value of the first half of the FFT
289
# which represents the positive frequencies
290
return tf.math.abs(fft[:, : (audio.shape[1] // 2), :])
291
292
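
"""
A quick way to sanity-check `audio_to_fft` is to run it on a dummy batch: for a
1-second sample at 16000 Hz, the output keeps only the 8000 positive-frequency
magnitudes. A minimal sketch with random data (not part of the pipeline):

```python
dummy = tf.random.uniform((4, SAMPLING_RATE, 1))  # (batch, time, channels)
print(audio_to_fft(dummy).shape)  # expected: (4, 8000, 1)
```
"""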

# Get the list of audio file paths along with their corresponding labels

class_names = os.listdir(DATASET_AUDIO_PATH)
print(
    "Our class names: {}".format(
        class_names,
    )
)

audio_paths = []
labels = []
for label, name in enumerate(class_names):
    print(
        "Processing speaker {}".format(
            name,
        )
    )
    dir_path = Path(DATASET_AUDIO_PATH) / name
    speaker_sample_paths = [
        os.path.join(dir_path, filepath)
        for filepath in os.listdir(dir_path)
        if filepath.endswith(".wav")
    ]
    audio_paths += speaker_sample_paths
    labels += [label] * len(speaker_sample_paths)

print(
    "Found {} files belonging to {} classes.".format(len(audio_paths), len(class_names))
)

# Shuffle
rng = np.random.RandomState(SHUFFLE_SEED)
rng.shuffle(audio_paths)
rng = np.random.RandomState(SHUFFLE_SEED)
rng.shuffle(labels)
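
"""
Reseeding `np.random.RandomState` with the same value before each `shuffle` call
produces the same permutation twice, so `audio_paths` and `labels` stay aligned.
A minimal sketch (hypothetical lists, not part of the pipeline):

```python
paths = ["a.wav", "b.wav", "c.wav"]
ids = [0, 1, 2]
np.random.RandomState(43).shuffle(paths)
np.random.RandomState(43).shuffle(ids)
# paths and ids are permuted identically, preserving the pairing
```
"""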

# Split into training and validation
num_val_samples = int(VALID_SPLIT * len(audio_paths))
print("Using {} files for training.".format(len(audio_paths) - num_val_samples))
train_audio_paths = audio_paths[:-num_val_samples]
train_labels = labels[:-num_val_samples]

print("Using {} files for validation.".format(num_val_samples))
valid_audio_paths = audio_paths[-num_val_samples:]
valid_labels = labels[-num_val_samples:]

# Create 2 datasets, one for training and the other for validation
train_ds = paths_and_labels_to_dataset(train_audio_paths, train_labels)
train_ds = train_ds.shuffle(buffer_size=BATCH_SIZE * 8, seed=SHUFFLE_SEED).batch(
    BATCH_SIZE
)

valid_ds = paths_and_labels_to_dataset(valid_audio_paths, valid_labels)
valid_ds = valid_ds.shuffle(buffer_size=32 * 8, seed=SHUFFLE_SEED).batch(32)


# Add noise to the training set
train_ds = train_ds.map(
    lambda x, y: (add_noise(x, noises, scale=SCALE), y),
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Transform audio wave to the frequency domain using `audio_to_fft`
train_ds = train_ds.map(
    lambda x, y: (audio_to_fft(x), y), num_parallel_calls=tf.data.AUTOTUNE
)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

valid_ds = valid_ds.map(
    lambda x, y: (audio_to_fft(x), y), num_parallel_calls=tf.data.AUTOTUNE
)
valid_ds = valid_ds.prefetch(tf.data.AUTOTUNE)
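
"""
Each element of `train_ds` is now an `(fft, label)` pair. Inspecting one batch is a
useful sanity check before building the model (a sketch; the last batch of an epoch
may be smaller than `BATCH_SIZE`):

```python
for x, y in train_ds.take(1):
    print(x.shape, y.shape)  # expected: (128, 8000, 1) (128,)
```
"""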
"""
367
## Model Definition
368
"""
369
370
371
def residual_block(x, filters, conv_num=3, activation="relu"):
372
# Shortcut
373
s = keras.layers.Conv1D(filters, 1, padding="same")(x)
374
for i in range(conv_num - 1):
375
x = keras.layers.Conv1D(filters, 3, padding="same")(x)
376
x = keras.layers.Activation(activation)(x)
377
x = keras.layers.Conv1D(filters, 3, padding="same")(x)
378
x = keras.layers.Add()([x, s])
379
x = keras.layers.Activation(activation)(x)
380
return keras.layers.MaxPool1D(pool_size=2, strides=2)(x)
381
382
383
def build_model(input_shape, num_classes):
384
inputs = keras.layers.Input(shape=input_shape, name="input")
385
386
x = residual_block(inputs, 16, 2)
387
x = residual_block(x, 32, 2)
388
x = residual_block(x, 64, 3)
389
x = residual_block(x, 128, 3)
390
x = residual_block(x, 128, 3)
391
392
x = keras.layers.AveragePooling1D(pool_size=3, strides=3)(x)
393
x = keras.layers.Flatten()(x)
394
x = keras.layers.Dense(256, activation="relu")(x)
395
x = keras.layers.Dense(128, activation="relu")(x)
396
397
outputs = keras.layers.Dense(num_classes, activation="softmax", name="output")(x)
398
399
return keras.models.Model(inputs=inputs, outputs=outputs)
400
401
402
model = build_model((SAMPLING_RATE // 2, 1), len(class_names))
403
404
model.summary()
405
406
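
"""
Each of the five residual blocks ends with a stride-2 max-pool, so the 8000-point FFT
input is progressively halved: 8000 → 4000 → 2000 → 1000 → 500 → 250. The final
average pooling (pool size 3, stride 3) then reduces this to 83 steps of 128 features
before flattening (these figures assume the default "valid" pooling padding).
"""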

# Compile the model using the Adam optimizer with its default learning rate
model.compile(
    optimizer="Adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Add callbacks:
# 'EarlyStopping' to stop training when the model is no longer improving
# 'ModelCheckpoint' to always keep the model that has the best val_accuracy
model_save_filename = "model.keras"

earlystopping_cb = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
mdlcheckpoint_cb = keras.callbacks.ModelCheckpoint(
    model_save_filename, monitor="val_accuracy", save_best_only=True
)
"""
424
## Training
425
"""
426
427
history = model.fit(
428
train_ds,
429
epochs=EPOCHS,
430
validation_data=valid_ds,
431
callbacks=[earlystopping_cb, mdlcheckpoint_cb],
432
)
433
434
"""
435
## Evaluation
436
"""
437
438
print(model.evaluate(valid_ds))
439
440
"""
441
We get ~ 98% validation accuracy.
442
"""
443
444
"""
445
## Demonstration
446
447
Let's take some samples and:
448
449
- Predict the speaker
450
- Compare the prediction with the real speaker
451
- Listen to the audio to see that despite the samples being noisy,
452
the model is still pretty accurate
453
"""
454
455
SAMPLES_TO_DISPLAY = 10
456
457
test_ds = paths_and_labels_to_dataset(valid_audio_paths, valid_labels)
458
test_ds = test_ds.shuffle(buffer_size=BATCH_SIZE * 8, seed=SHUFFLE_SEED).batch(
459
BATCH_SIZE
460
)
461
462
test_ds = test_ds.map(
463
lambda x, y: (add_noise(x, noises, scale=SCALE), y),
464
num_parallel_calls=tf.data.AUTOTUNE,
465
)
466
467
for audios, labels in test_ds.take(1):
468
# Get the signal FFT
469
ffts = audio_to_fft(audios)
470
# Predict
471
y_pred = model.predict(ffts)
472
# Take random samples
473
rnd = np.random.randint(0, BATCH_SIZE, SAMPLES_TO_DISPLAY)
474
audios = audios.numpy()[rnd, :, :]
475
labels = labels.numpy()[rnd]
476
y_pred = np.argmax(y_pred, axis=-1)[rnd]
477
478
for index in range(SAMPLES_TO_DISPLAY):
479
# For every sample, print the true and predicted label
480
# as well as run the voice with the noise
481
print(
482
"Speaker:\33{} {}\33[0m\tPredicted:\33{} {}\33[0m".format(
483
"[92m" if labels[index] == y_pred[index] else "[91m",
484
class_names[labels[index]],
485
"[92m" if labels[index] == y_pred[index] else "[91m",
486
class_names[y_pred[index]],
487
)
488
)
489
display(Audio(audios[index, :, :].squeeze(), rate=SAMPLING_RATE))
490
491