Path: blob/master/examples/audio/speaker_recognition_using_cnn.py
3507 views
"""1Title: Speaker Recognition2Author: [Fadi Badine](https://twitter.com/fadibadine)3Date created: 14/06/20204Last modified: 19/07/20235Description: Classify speakers using Fast Fourier Transform (FFT) and a 1D Convnet.6Accelerator: GPU7Converted to Keras 3 by: [Fadi Badine](https://twitter.com/fadibadine)8"""910"""11## Introduction1213This example demonstrates how to create a model to classify speakers from the14frequency domain representation of speech recordings, obtained via Fast Fourier15Transform (FFT).1617It shows the following:1819- How to use `tf.data` to load, preprocess and feed audio streams into a model20- How to create a 1D convolutional network with residual21connections for audio classification.2223Our process:2425- We prepare a dataset of speech samples from different speakers, with the speaker as label.26- We add background noise to these samples to augment our data.27- We take the FFT of these samples.28- We train a 1D convnet to predict the correct speaker given a noisy FFT speech sample.2930Note:3132- This example should be run with TensorFlow 2.3 or higher, or `tf-nightly`.33- The noise samples in the dataset need to be resampled to a sampling rate of 16000 Hz34before using the code in this example. In order to do this, you will need to have35installed `ffmpg`.36"""3738"""39## Setup40"""4142import os4344os.environ["KERAS_BACKEND"] = "tensorflow"4546import shutil47import numpy as np4849import tensorflow as tf50import keras5152from pathlib import Path53from IPython.display import display, Audio5455# Get the data from https://www.kaggle.com/kongaevans/speaker-recognition-dataset/56# and save it to ./speaker-recognition-dataset.zip57# then unzip it to ./16000_pcm_speeches58"""shell59kaggle datasets download -d kongaevans/speaker-recognition-dataset60unzip -qq speaker-recognition-dataset.zip61"""6263DATASET_ROOT = "16000_pcm_speeches"6465# The folders in which we will put the audio samples and the noise samples66AUDIO_SUBFOLDER = "audio"67NOISE_SUBFOLDER = "noise"6869DATASET_AUDIO_PATH = os.path.join(DATASET_ROOT, AUDIO_SUBFOLDER)70DATASET_NOISE_PATH = os.path.join(DATASET_ROOT, NOISE_SUBFOLDER)7172# Percentage of samples to use for validation73VALID_SPLIT = 0.17475# Seed to use when shuffling the dataset and the noise76SHUFFLE_SEED = 437778# The sampling rate to use.79# This is the one used in all the audio samples.80# We will resample all the noise to this sampling rate.81# This will also be the output size of the audio wave samples82# (since all samples are of 1 second long)83SAMPLING_RATE = 160008485# The factor to multiply the noise with according to:86# noisy_sample = sample + noise * prop * scale87# where prop = sample_amplitude / noise_amplitude88SCALE = 0.58990BATCH_SIZE = 12891EPOCHS = 1 # For a real training run, use EPOCHS = 100929394"""95## Data preparation9697The dataset is composed of 7 folders, divided into 2 groups:9899- Speech samples, with 5 folders for 5 different speakers. Each folder contains1001500 audio files, each 1 second long and sampled at 16000 Hz.101- Background noise samples, with 2 folders and a total of 6 files. These files102are longer than 1 second (and originally not sampled at 16000 Hz, but we will resample them to 16000 Hz).103We will use those 6 files to create 354 1-second-long noise samples to be used for training.104105Let's sort these 2 categories into 2 folders:106107- An `audio` folder which will contain all the per-speaker speech sample folders108- A `noise` folder which will contain all the noise samples109"""110111"""112Before sorting the audio and noise categories into 2 folders,113we have the following directory structure:114115```116main_directory/117...speaker_a/118...speaker_b/119...speaker_c/120...speaker_d/121...speaker_e/122...other/123..._background_noise_/124```125126After sorting, we end up with the following structure:127128```129main_directory/130...audio/131......speaker_a/132......speaker_b/133......speaker_c/134......speaker_d/135......speaker_e/136...noise/137......other/138......_background_noise_/139```140"""141142for folder in os.listdir(DATASET_ROOT):143if os.path.isdir(os.path.join(DATASET_ROOT, folder)):144if folder in [AUDIO_SUBFOLDER, NOISE_SUBFOLDER]:145# If folder is `audio` or `noise`, do nothing146continue147elif folder in ["other", "_background_noise_"]:148# If folder is one of the folders that contains noise samples,149# move it to the `noise` folder150shutil.move(151os.path.join(DATASET_ROOT, folder),152os.path.join(DATASET_NOISE_PATH, folder),153)154else:155# Otherwise, it should be a speaker folder, then move it to156# `audio` folder157shutil.move(158os.path.join(DATASET_ROOT, folder),159os.path.join(DATASET_AUDIO_PATH, folder),160)161162"""163## Noise preparation164165In this section:166167- We load all noise samples (which should have been resampled to 16000)168- We split those noise samples to chunks of 16000 samples which169correspond to 1 second duration each170"""171172# Get the list of all noise files173noise_paths = []174for subdir in os.listdir(DATASET_NOISE_PATH):175subdir_path = Path(DATASET_NOISE_PATH) / subdir176if os.path.isdir(subdir_path):177noise_paths += [178os.path.join(subdir_path, filepath)179for filepath in os.listdir(subdir_path)180if filepath.endswith(".wav")181]182if not noise_paths:183raise RuntimeError(f"Could not find any files at {DATASET_NOISE_PATH}")184print(185"Found {} files belonging to {} directories".format(186len(noise_paths), len(os.listdir(DATASET_NOISE_PATH))187)188)189190"""191Resample all noise samples to 16000 Hz192"""193194command = (195"for dir in `ls -1 " + DATASET_NOISE_PATH + "`; do "196"for file in `ls -1 " + DATASET_NOISE_PATH + "/$dir/*.wav`; do "197"sample_rate=`ffprobe -hide_banner -loglevel panic -show_streams "198"$file | grep sample_rate | cut -f2 -d=`; "199"if [ $sample_rate -ne 16000 ]; then "200"ffmpeg -hide_banner -loglevel panic -y "201"-i $file -ar 16000 temp.wav; "202"mv temp.wav $file; "203"fi; done; done"204)205os.system(command)206207208# Split noise into chunks of 16,000 steps each209def load_noise_sample(path):210sample, sampling_rate = tf.audio.decode_wav(211tf.io.read_file(path), desired_channels=1212)213if sampling_rate == SAMPLING_RATE:214# Number of slices of 16000 each that can be generated from the noise sample215slices = int(sample.shape[0] / SAMPLING_RATE)216sample = tf.split(sample[: slices * SAMPLING_RATE], slices)217return sample218else:219print("Sampling rate for {} is incorrect. Ignoring it".format(path))220return None221222223noises = []224for path in noise_paths:225sample = load_noise_sample(path)226if sample:227noises.extend(sample)228noises = tf.stack(noises)229230print(231"{} noise files were split into {} noise samples where each is {} sec. long".format(232len(noise_paths), noises.shape[0], noises.shape[1] // SAMPLING_RATE233)234)235236"""237## Dataset generation238"""239240241def paths_and_labels_to_dataset(audio_paths, labels):242"""Constructs a dataset of audios and labels."""243path_ds = tf.data.Dataset.from_tensor_slices(audio_paths)244audio_ds = path_ds.map(245lambda x: path_to_audio(x), num_parallel_calls=tf.data.AUTOTUNE246)247label_ds = tf.data.Dataset.from_tensor_slices(labels)248return tf.data.Dataset.zip((audio_ds, label_ds))249250251def path_to_audio(path):252"""Reads and decodes an audio file."""253audio = tf.io.read_file(path)254audio, _ = tf.audio.decode_wav(audio, 1, SAMPLING_RATE)255return audio256257258def add_noise(audio, noises=None, scale=0.5):259if noises is not None:260# Create a random tensor of the same size as audio ranging from261# 0 to the number of noise stream samples that we have.262tf_rnd = tf.random.uniform(263(tf.shape(audio)[0],), 0, noises.shape[0], dtype=tf.int32264)265noise = tf.gather(noises, tf_rnd, axis=0)266267# Get the amplitude proportion between the audio and the noise268prop = tf.math.reduce_max(audio, axis=1) / tf.math.reduce_max(noise, axis=1)269prop = tf.repeat(tf.expand_dims(prop, axis=1), tf.shape(audio)[1], axis=1)270271# Adding the rescaled noise to audio272audio = audio + noise * prop * scale273274return audio275276277def audio_to_fft(audio):278# Since tf.signal.fft applies FFT on the innermost dimension,279# we need to squeeze the dimensions and then expand them again280# after FFT281audio = tf.squeeze(audio, axis=-1)282fft = tf.signal.fft(283tf.cast(tf.complex(real=audio, imag=tf.zeros_like(audio)), tf.complex64)284)285fft = tf.expand_dims(fft, axis=-1)286287# Return the absolute value of the first half of the FFT288# which represents the positive frequencies289return tf.math.abs(fft[:, : (audio.shape[1] // 2), :])290291292# Get the list of audio file paths along with their corresponding labels293294class_names = os.listdir(DATASET_AUDIO_PATH)295print(296"Our class names: {}".format(297class_names,298)299)300301audio_paths = []302labels = []303for label, name in enumerate(class_names):304print(305"Processing speaker {}".format(306name,307)308)309dir_path = Path(DATASET_AUDIO_PATH) / name310speaker_sample_paths = [311os.path.join(dir_path, filepath)312for filepath in os.listdir(dir_path)313if filepath.endswith(".wav")314]315audio_paths += speaker_sample_paths316labels += [label] * len(speaker_sample_paths)317318print(319"Found {} files belonging to {} classes.".format(len(audio_paths), len(class_names))320)321322# Shuffle323rng = np.random.RandomState(SHUFFLE_SEED)324rng.shuffle(audio_paths)325rng = np.random.RandomState(SHUFFLE_SEED)326rng.shuffle(labels)327328# Split into training and validation329num_val_samples = int(VALID_SPLIT * len(audio_paths))330print("Using {} files for training.".format(len(audio_paths) - num_val_samples))331train_audio_paths = audio_paths[:-num_val_samples]332train_labels = labels[:-num_val_samples]333334print("Using {} files for validation.".format(num_val_samples))335valid_audio_paths = audio_paths[-num_val_samples:]336valid_labels = labels[-num_val_samples:]337338# Create 2 datasets, one for training and the other for validation339train_ds = paths_and_labels_to_dataset(train_audio_paths, train_labels)340train_ds = train_ds.shuffle(buffer_size=BATCH_SIZE * 8, seed=SHUFFLE_SEED).batch(341BATCH_SIZE342)343344valid_ds = paths_and_labels_to_dataset(valid_audio_paths, valid_labels)345valid_ds = valid_ds.shuffle(buffer_size=32 * 8, seed=SHUFFLE_SEED).batch(32)346347348# Add noise to the training set349train_ds = train_ds.map(350lambda x, y: (add_noise(x, noises, scale=SCALE), y),351num_parallel_calls=tf.data.AUTOTUNE,352)353354# Transform audio wave to the frequency domain using `audio_to_fft`355train_ds = train_ds.map(356lambda x, y: (audio_to_fft(x), y), num_parallel_calls=tf.data.AUTOTUNE357)358train_ds = train_ds.prefetch(tf.data.AUTOTUNE)359360valid_ds = valid_ds.map(361lambda x, y: (audio_to_fft(x), y), num_parallel_calls=tf.data.AUTOTUNE362)363valid_ds = valid_ds.prefetch(tf.data.AUTOTUNE)364365"""366## Model Definition367"""368369370def residual_block(x, filters, conv_num=3, activation="relu"):371# Shortcut372s = keras.layers.Conv1D(filters, 1, padding="same")(x)373for i in range(conv_num - 1):374x = keras.layers.Conv1D(filters, 3, padding="same")(x)375x = keras.layers.Activation(activation)(x)376x = keras.layers.Conv1D(filters, 3, padding="same")(x)377x = keras.layers.Add()([x, s])378x = keras.layers.Activation(activation)(x)379return keras.layers.MaxPool1D(pool_size=2, strides=2)(x)380381382def build_model(input_shape, num_classes):383inputs = keras.layers.Input(shape=input_shape, name="input")384385x = residual_block(inputs, 16, 2)386x = residual_block(x, 32, 2)387x = residual_block(x, 64, 3)388x = residual_block(x, 128, 3)389x = residual_block(x, 128, 3)390391x = keras.layers.AveragePooling1D(pool_size=3, strides=3)(x)392x = keras.layers.Flatten()(x)393x = keras.layers.Dense(256, activation="relu")(x)394x = keras.layers.Dense(128, activation="relu")(x)395396outputs = keras.layers.Dense(num_classes, activation="softmax", name="output")(x)397398return keras.models.Model(inputs=inputs, outputs=outputs)399400401model = build_model((SAMPLING_RATE // 2, 1), len(class_names))402403model.summary()404405# Compile the model using Adam's default learning rate406model.compile(407optimizer="Adam",408loss="sparse_categorical_crossentropy",409metrics=["accuracy"],410)411412# Add callbacks:413# 'EarlyStopping' to stop training when the model is not enhancing anymore414# 'ModelCheckPoint' to always keep the model that has the best val_accuracy415model_save_filename = "model.keras"416417earlystopping_cb = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)418mdlcheckpoint_cb = keras.callbacks.ModelCheckpoint(419model_save_filename, monitor="val_accuracy", save_best_only=True420)421422"""423## Training424"""425426history = model.fit(427train_ds,428epochs=EPOCHS,429validation_data=valid_ds,430callbacks=[earlystopping_cb, mdlcheckpoint_cb],431)432433"""434## Evaluation435"""436437print(model.evaluate(valid_ds))438439"""440We get ~ 98% validation accuracy.441"""442443"""444## Demonstration445446Let's take some samples and:447448- Predict the speaker449- Compare the prediction with the real speaker450- Listen to the audio to see that despite the samples being noisy,451the model is still pretty accurate452"""453454SAMPLES_TO_DISPLAY = 10455456test_ds = paths_and_labels_to_dataset(valid_audio_paths, valid_labels)457test_ds = test_ds.shuffle(buffer_size=BATCH_SIZE * 8, seed=SHUFFLE_SEED).batch(458BATCH_SIZE459)460461test_ds = test_ds.map(462lambda x, y: (add_noise(x, noises, scale=SCALE), y),463num_parallel_calls=tf.data.AUTOTUNE,464)465466for audios, labels in test_ds.take(1):467# Get the signal FFT468ffts = audio_to_fft(audios)469# Predict470y_pred = model.predict(ffts)471# Take random samples472rnd = np.random.randint(0, BATCH_SIZE, SAMPLES_TO_DISPLAY)473audios = audios.numpy()[rnd, :, :]474labels = labels.numpy()[rnd]475y_pred = np.argmax(y_pred, axis=-1)[rnd]476477for index in range(SAMPLES_TO_DISPLAY):478# For every sample, print the true and predicted label479# as well as run the voice with the noise480print(481"Speaker:\33{} {}\33[0m\tPredicted:\33{} {}\33[0m".format(482"[92m" if labels[index] == y_pred[index] else "[91m",483class_names[labels[index]],484"[92m" if labels[index] == y_pred[index] else "[91m",485class_names[y_pred[index]],486)487)488display(Audio(audios[index, :, :].squeeze(), rate=SAMPLING_RATE))489490491