Path: blob/master/examples/audio/uk_ireland_accent_recognition.py
3507 views
"""1Title: English speaker accent recognition using Transfer Learning2Author: [Fadi Badine](https://twitter.com/fadibadine)3Date created: 2022/04/164Last modified: 2022/04/165Description: Training a model to classify UK & Ireland accents using feature extraction from Yamnet.6Accelerator: GPU7"""89"""10## Introduction1112The following example shows how to use feature extraction in order to13train a model to classify the English accent spoken in an audio wave.1415Instead of training a model from scratch, transfer learning enables us to16take advantage of existing state-of-the-art deep learning models and use them as feature extractors.1718Our process:1920* Use a TF Hub pre-trained model (Yamnet) and apply it as part of the tf.data pipeline which transforms21the audio files into feature vectors.22* Train a dense model on the feature vectors.23* Use the trained model for inference on a new audio file.2425Note:2627* We need to install TensorFlow IO in order to resample audio files to 16 kHz as required by Yamnet model.28* In the test section, ffmpeg is used to convert the mp3 file to wav.2930You can install TensorFlow IO with the following command:31"""3233"""shell34pip install -U -q tensorflow_io35"""3637"""38## Configuration39"""4041SEED = 133742EPOCHS = 10043BATCH_SIZE = 6444VALIDATION_RATIO = 0.145MODEL_NAME = "uk_irish_accent_recognition"4647# Location where the dataset will be downloaded.48# By default (None), keras.utils.get_file will use ~/.keras/ as the CACHE_DIR49CACHE_DIR = None5051# The location of the dataset52URL_PATH = "https://www.openslr.org/resources/83/"5354# List of datasets compressed files that contain the audio files55zip_files = {560: "irish_english_male.zip",571: "midlands_english_female.zip",582: "midlands_english_male.zip",593: "northern_english_female.zip",604: "northern_english_male.zip",615: "scottish_english_female.zip",626: "scottish_english_male.zip",637: "southern_english_female.zip",648: "southern_english_male.zip",659: "welsh_english_female.zip",6610: "welsh_english_male.zip",67}6869# We see that there are 2 compressed files for each accent (except Irish):70# - One for male speakers71# - One for female speakers72# However, we will be using a gender agnostic dataset.7374# List of gender agnostic categories75gender_agnostic_categories = [76"ir", # Irish77"mi", # Midlands78"no", # Northern79"sc", # Scottish80"so", # Southern81"we", # Welsh82]8384class_names = [85"Irish",86"Midlands",87"Northern",88"Scottish",89"Southern",90"Welsh",91"Not a speech",92]9394"""95## Imports96"""9798import os99import io100import csv101import numpy as np102import pandas as pd103import tensorflow as tf104import tensorflow_hub as hub105import tensorflow_io as tfio106from tensorflow import keras107import matplotlib.pyplot as plt108import seaborn as sns109from scipy import stats110from IPython.display import Audio111112113# Set all random seeds in order to get reproducible results114keras.utils.set_random_seed(SEED)115116# Where to download the dataset117DATASET_DESTINATION = os.path.join(CACHE_DIR if CACHE_DIR else "~/.keras/", "datasets")118119"""120## Yamnet Model121122Yamnet is an audio event classifier trained on the AudioSet dataset to predict audio123events from the AudioSet ontology. It is available on TensorFlow Hub.124125Yamnet accepts a 1-D tensor of audio samples with a sample rate of 16 kHz.126As output, the model returns a 3-tuple:127128* Scores of shape `(N, 521)` representing the scores of the 521 classes.129* Embeddings of shape `(N, 1024)`.130* The log-mel spectrogram of the entire audio frame.131132We will use the embeddings, which are the features extracted from the audio samples, as the input to our dense model.133134For more detailed information about Yamnet, please refer to its [TensorFlow Hub](https://tfhub.dev/google/yamnet/1) page.135"""136137yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")138139"""140## Dataset141142The dataset used is the143[Crowdsourced high-quality UK and Ireland English Dialect speech data set](https://openslr.org/83/)144which consists of a total of 17,877 high-quality audio wav files.145146This dataset includes over 31 hours of recording from 120 volunteers who self-identify as147native speakers of Southern England, Midlands, Northern England, Wales, Scotland and Ireland.148149For more info, please refer to the above link or to the following paper:150[Open-source Multi-speaker Corpora of the English Accents in the British Isles](https://aclanthology.org/2020.lrec-1.804.pdf)151"""152153"""154## Download the data155"""156157# CSV file that contains information about the dataset. For each entry, we have:158# - ID159# - wav file name160# - transcript161line_index_file = keras.utils.get_file(162fname="line_index_file", origin=URL_PATH + "line_index_all.csv"163)164165# Download the list of compressed files that contain the audio wav files166for i in zip_files:167fname = zip_files[i].split(".")[0]168url = URL_PATH + zip_files[i]169170zip_file = keras.utils.get_file(fname=fname, origin=url, extract=True)171os.remove(zip_file)172173"""174## Load the data in a Dataframe175176Of the 3 columns (ID, filename and transcript), we are only interested in the filename column in order to read the audio file.177We will ignore the other two.178"""179180dataframe = pd.read_csv(181line_index_file, names=["id", "filename", "transcript"], usecols=["filename"]182)183dataframe.head()184185"""186Let's now preprocess the dataset by:187188* Adjusting the filename (removing a leading space & adding ".wav" extension to the189filename).190* Creating a label using the first 2 characters of the filename which indicate the191accent.192* Shuffling the samples.193"""194195196# The purpose of this function is to preprocess the dataframe by applying the following:197# - Cleaning the filename from a leading space198# - Generating a label column that is gender agnostic i.e.199# welsh english male and welsh english female for example are both labeled as200# welsh english201# - Add extension .wav to the filename202# - Shuffle samples203def preprocess_dataframe(dataframe):204# Remove leading space in filename column205dataframe["filename"] = dataframe.apply(lambda row: row["filename"].strip(), axis=1)206207# Create gender agnostic labels based on the filename first 2 letters208dataframe["label"] = dataframe.apply(209lambda row: gender_agnostic_categories.index(row["filename"][:2]), axis=1210)211212# Add the file path to the name213dataframe["filename"] = dataframe.apply(214lambda row: os.path.join(DATASET_DESTINATION, row["filename"] + ".wav"), axis=1215)216217# Shuffle the samples218dataframe = dataframe.sample(frac=1, random_state=SEED).reset_index(drop=True)219220return dataframe221222223dataframe = preprocess_dataframe(dataframe)224dataframe.head()225226"""227## Prepare training & validation sets228229Let's split the samples creating training and validation sets.230"""231232split = int(len(dataframe) * (1 - VALIDATION_RATIO))233train_df = dataframe[:split]234valid_df = dataframe[split:]235236print(237f"We have {train_df.shape[0]} training samples & {valid_df.shape[0]} validation ones"238)239240"""241## Prepare a TensorFlow Dataset242243Next, we need to create a `tf.data.Dataset`.244This is done by creating a `dataframe_to_dataset` function that does the following:245246* Create a dataset using filenames and labels.247* Get the Yamnet embeddings by calling another function `filepath_to_embeddings`.248* Apply caching, reshuffling and setting batch size.249250The `filepath_to_embeddings` does the following:251252* Load audio file.253* Resample audio to 16 kHz.254* Generate scores and embeddings from Yamnet model.255* Since Yamnet generates multiple samples for each audio file,256this function also duplicates the label for all the generated samples257that have `score=0` (speech) whereas sets the label for the others as258'other' indicating that this audio segment is not a speech and we won't label it as one of the accents.259260The below `load_16k_audio_file` is copied from the following tutorial261[Transfer learning with YAMNet for environmental sound classification](https://www.tensorflow.org/tutorials/audio/transfer_learning_audio)262"""263264265@tf.function266def load_16k_audio_wav(filename):267# Read file content268file_content = tf.io.read_file(filename)269270# Decode audio wave271audio_wav, sample_rate = tf.audio.decode_wav(file_content, desired_channels=1)272audio_wav = tf.squeeze(audio_wav, axis=-1)273sample_rate = tf.cast(sample_rate, dtype=tf.int64)274275# Resample to 16k276audio_wav = tfio.audio.resample(audio_wav, rate_in=sample_rate, rate_out=16000)277278return audio_wav279280281def filepath_to_embeddings(filename, label):282# Load 16k audio wave283audio_wav = load_16k_audio_wav(filename)284285# Get audio embeddings & scores.286# The embeddings are the audio features extracted using transfer learning287# while scores will be used to identify time slots that are not speech288# which will then be gathered into a specific new category 'other'289scores, embeddings, _ = yamnet_model(audio_wav)290291# Number of embeddings in order to know how many times to repeat the label292embeddings_num = tf.shape(embeddings)[0]293labels = tf.repeat(label, embeddings_num)294295# Change labels for time-slots that are not speech into a new category 'other'296labels = tf.where(tf.argmax(scores, axis=1) == 0, label, len(class_names) - 1)297298# Using one-hot in order to use AUC299return (embeddings, tf.one_hot(labels, len(class_names)))300301302def dataframe_to_dataset(dataframe, batch_size=64):303dataset = tf.data.Dataset.from_tensor_slices(304(dataframe["filename"], dataframe["label"])305)306307dataset = dataset.map(308lambda x, y: filepath_to_embeddings(x, y),309num_parallel_calls=tf.data.experimental.AUTOTUNE,310).unbatch()311312return dataset.cache().batch(batch_size).prefetch(tf.data.AUTOTUNE)313314315train_ds = dataframe_to_dataset(train_df)316valid_ds = dataframe_to_dataset(valid_df)317318"""319## Build the model320321The model that we use consists of:322323* An input layer which is the embedding output of the Yamnet classifier.324* 4 dense hidden layers and 4 dropout layers.325* An output dense layer.326327The model's hyperparameters were selected using328[KerasTuner](https://keras.io/keras_tuner/).329"""330331keras.backend.clear_session()332333334def build_and_compile_model():335inputs = keras.layers.Input(shape=(1024), name="embedding")336337x = keras.layers.Dense(256, activation="relu", name="dense_1")(inputs)338x = keras.layers.Dropout(0.15, name="dropout_1")(x)339340x = keras.layers.Dense(384, activation="relu", name="dense_2")(x)341x = keras.layers.Dropout(0.2, name="dropout_2")(x)342343x = keras.layers.Dense(192, activation="relu", name="dense_3")(x)344x = keras.layers.Dropout(0.25, name="dropout_3")(x)345346x = keras.layers.Dense(384, activation="relu", name="dense_4")(x)347x = keras.layers.Dropout(0.2, name="dropout_4")(x)348349outputs = keras.layers.Dense(len(class_names), activation="softmax", name="ouput")(350x351)352353model = keras.Model(inputs=inputs, outputs=outputs, name="accent_recognition")354355model.compile(356optimizer=keras.optimizers.Adam(learning_rate=1.9644e-5),357loss=keras.losses.CategoricalCrossentropy(),358metrics=["accuracy", keras.metrics.AUC(name="auc")],359)360361return model362363364model = build_and_compile_model()365model.summary()366367"""368## Class weights calculation369370Since the dataset is quite unbalanced, we will use `class_weight` argument during training.371372Getting the class weights is a little tricky because even though we know the number of373audio files for each class, it does not represent the number of samples for that class374since Yamnet transforms each audio file into multiple audio samples of 0.96 seconds each.375So every audio file will be split into a number of samples that is proportional to its length.376377Therefore, to get those weights, we have to calculate the number of samples for each class378after preprocessing through Yamnet.379"""380381class_counts = tf.zeros(shape=(len(class_names),), dtype=tf.int32)382383for x, y in iter(train_ds):384class_counts = class_counts + tf.math.bincount(385tf.cast(tf.math.argmax(y, axis=1), tf.int32), minlength=len(class_names)386)387388class_weight = {389i: tf.math.reduce_sum(class_counts).numpy() / class_counts[i].numpy()390for i in range(len(class_counts))391}392393print(class_weight)394395"""396## Callbacks397398We use Keras callbacks in order to:399400* Stop whenever the validation AUC stops improving.401* Save the best model.402* Call TensorBoard in order to later view the training and validation logs.403"""404405early_stopping_cb = keras.callbacks.EarlyStopping(406monitor="val_auc", patience=10, restore_best_weights=True407)408409model_checkpoint_cb = keras.callbacks.ModelCheckpoint(410MODEL_NAME + ".h5", monitor="val_auc", save_best_only=True411)412413tensorboard_cb = keras.callbacks.TensorBoard(414os.path.join(os.curdir, "logs", model.name)415)416417callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]418419"""420## Training421"""422423history = model.fit(424train_ds,425epochs=EPOCHS,426validation_data=valid_ds,427class_weight=class_weight,428callbacks=callbacks,429verbose=2,430)431432"""433## Results434435Let's plot the training and validation AUC and accuracy.436"""437438fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(14, 5))439440axs[0].plot(range(EPOCHS), history.history["accuracy"], label="Training")441axs[0].plot(range(EPOCHS), history.history["val_accuracy"], label="Validation")442axs[0].set_xlabel("Epochs")443axs[0].set_title("Training & Validation Accuracy")444axs[0].legend()445axs[0].grid(True)446447axs[1].plot(range(EPOCHS), history.history["auc"], label="Training")448axs[1].plot(range(EPOCHS), history.history["val_auc"], label="Validation")449axs[1].set_xlabel("Epochs")450axs[1].set_title("Training & Validation AUC")451axs[1].legend()452axs[1].grid(True)453454plt.show()455456"""457## Evaluation458"""459460train_loss, train_acc, train_auc = model.evaluate(train_ds)461valid_loss, valid_acc, valid_auc = model.evaluate(valid_ds)462463"""464Let's try to compare our model's performance to Yamnet's using one of Yamnet metrics (d-prime)465Yamnet achieved a d-prime value of 2.318.466Let's check our model's performance.467"""468469470# The following function calculates the d-prime score from the AUC471def d_prime(auc):472standard_normal = stats.norm()473d_prime = standard_normal.ppf(auc) * np.sqrt(2.0)474return d_prime475476477print(478"train d-prime: {0:.3f}, validation d-prime: {1:.3f}".format(479d_prime(train_auc), d_prime(valid_auc)480)481)482483"""484We can see that the model achieves the following results:485486Results | Training | Validation487-----------|-----------|------------488Accuracy | 54% | 51%489AUC | 0.91 | 0.89490d-prime | 1.882 | 1.740491492"""493494"""495## Confusion Matrix496497Let's now plot the confusion matrix for the validation dataset.498499The confusion matrix lets us see, for every class, not only how many samples were correctly classified,500but also which other classes were the samples confused with.501502It allows us to calculate the precision and recall for every class.503"""504505# Create x and y tensors506x_valid = None507y_valid = None508509for x, y in iter(valid_ds):510if x_valid is None:511x_valid = x.numpy()512y_valid = y.numpy()513else:514x_valid = np.concatenate((x_valid, x.numpy()), axis=0)515y_valid = np.concatenate((y_valid, y.numpy()), axis=0)516517# Generate predictions518y_pred = model.predict(x_valid)519520# Calculate confusion matrix521confusion_mtx = tf.math.confusion_matrix(522np.argmax(y_valid, axis=1), np.argmax(y_pred, axis=1)523)524525# Plot the confusion matrix526plt.figure(figsize=(10, 8))527sns.heatmap(528confusion_mtx, xticklabels=class_names, yticklabels=class_names, annot=True, fmt="g"529)530plt.xlabel("Prediction")531plt.ylabel("Label")532plt.title("Validation Confusion Matrix")533plt.show()534535"""536## Precision & recall537538For every class:539540* Recall is the ratio of correctly classified samples i.e. it shows how many samples541of this specific class, the model is able to detect.542It is the ratio of diagonal elements to the sum of all elements in the row.543* Precision shows the accuracy of the classifier. It is the ratio of correctly predicted544samples among the ones classified as belonging to this class.545It is the ratio of diagonal elements to the sum of all elements in the column.546"""547548for i, label in enumerate(class_names):549precision = confusion_mtx[i, i] / np.sum(confusion_mtx[:, i])550recall = confusion_mtx[i, i] / np.sum(confusion_mtx[i, :])551print(552"{0:15} Precision:{1:.2f}%; Recall:{2:.2f}%".format(553label, precision * 100, recall * 100554)555)556557"""558## Run inference on test data559560Let's now run a test on a single audio file.561Let's check this example from [The Scottish Voice](https://www.thescottishvoice.org.uk/home/)562563We will:564565* Download the mp3 file.566* Convert it to a 16k wav file.567* Run the model on the wav file.568* Plot the results.569"""570571filename = "audio-sample-Stuart"572url = "https://www.thescottishvoice.org.uk/files/cm/files/"573574if os.path.exists(filename + ".wav") == False:575print(f"Downloading {filename}.mp3 from {url}")576command = f"wget {url}{filename}.mp3"577os.system(command)578579print(f"Converting mp3 to wav and resampling to 16 kHZ")580command = (581f"ffmpeg -hide_banner -loglevel panic -y -i {filename}.mp3 -acodec "582f"pcm_s16le -ac 1 -ar 16000 {filename}.wav"583)584os.system(command)585586filename = filename + ".wav"587588589"""590The below function `yamnet_class_names_from_csv` was copied and very slightly changed591from this [Yamnet Notebook](https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/yamnet.ipynb).592"""593594595def yamnet_class_names_from_csv(yamnet_class_map_csv_text):596"""Returns list of class names corresponding to score vector."""597yamnet_class_map_csv = io.StringIO(yamnet_class_map_csv_text)598yamnet_class_names = [599name for (class_index, mid, name) in csv.reader(yamnet_class_map_csv)600]601yamnet_class_names = yamnet_class_names[1:] # Skip CSV header602return yamnet_class_names603604605yamnet_class_map_path = yamnet_model.class_map_path().numpy()606yamnet_class_names = yamnet_class_names_from_csv(607tf.io.read_file(yamnet_class_map_path).numpy().decode("utf-8")608)609610611def calculate_number_of_non_speech(scores):612number_of_non_speech = tf.math.reduce_sum(613tf.where(tf.math.argmax(scores, axis=1, output_type=tf.int32) != 0, 1, 0)614)615616return number_of_non_speech617618619def filename_to_predictions(filename):620# Load 16k audio wave621audio_wav = load_16k_audio_wav(filename)622623# Get audio embeddings & scores.624scores, embeddings, mel_spectrogram = yamnet_model(audio_wav)625626print(627"Out of {} samples, {} are not speech".format(628scores.shape[0], calculate_number_of_non_speech(scores)629)630)631632# Predict the output of the accent recognition model with embeddings as input633predictions = model.predict(embeddings)634635return audio_wav, predictions, mel_spectrogram636637638"""639Let's run the model on the audio file:640"""641642audio_wav, predictions, mel_spectrogram = filename_to_predictions(filename)643644infered_class = class_names[predictions.mean(axis=0).argmax()]645print(f"The main accent is: {infered_class} English")646647"""648Listen to the audio649"""650651Audio(audio_wav, rate=16000)652653"""654The below function was copied from this [Yamnet notebook](tinyurl.com/4a8xn7at) and adjusted to our need.655656This function plots the following:657658* Audio waveform659* Mel spectrogram660* Predictions for every time step661"""662663plt.figure(figsize=(10, 6))664665# Plot the waveform.666plt.subplot(3, 1, 1)667plt.plot(audio_wav)668plt.xlim([0, len(audio_wav)])669670# Plot the log-mel spectrogram (returned by the model).671plt.subplot(3, 1, 2)672plt.imshow(673mel_spectrogram.numpy().T, aspect="auto", interpolation="nearest", origin="lower"674)675676# Plot and label the model output scores for the top-scoring classes.677mean_predictions = np.mean(predictions, axis=0)678679top_class_indices = np.argsort(mean_predictions)[::-1]680plt.subplot(3, 1, 3)681plt.imshow(682predictions[:, top_class_indices].T,683aspect="auto",684interpolation="nearest",685cmap="gray_r",686)687688# patch_padding = (PATCH_WINDOW_SECONDS / 2) / PATCH_HOP_SECONDS689# values from the model documentation690patch_padding = (0.025 / 2) / 0.01691plt.xlim([-patch_padding - 0.5, predictions.shape[0] + patch_padding - 0.5])692# Label the top_N classes.693yticks = range(0, len(class_names), 1)694plt.yticks(yticks, [class_names[top_class_indices[x]] for x in yticks])695_ = plt.ylim(-0.5 + np.array([len(class_names), 0]))696697698