Path: blob/master/examples/timeseries/eeg_signal_classification.py
3507 views
"""1Title: Electroencephalogram Signal Classification for action identification2Author: [Suvaditya Mukherjee](https://github.com/suvadityamuk)3Date created: 2022/11/034Last modified: 2022/11/055Description: Training a Convolutional model to classify EEG signals produced by exposure to certain stimuli.6Accelerator: GPU7"""89"""10## Introduction1112The following example explores how we can make a Convolution-based Neural Network to13perform classification on Electroencephalogram signals captured when subjects were14exposed to different stimuli.15We train a model from scratch since such signal-classification models are fairly scarce16in pre-trained format.17The data we use is sourced from the UC Berkeley-Biosense Lab where the data was collected18from 15 subjects at the same time.19Our process is as follows:2021- Load the [UC Berkeley-Biosense Synchronized Brainwave Dataset](https://www.kaggle.com/datasets/berkeley-biosense/synchronized-brainwave-dataset)22- Visualize random samples from the data23- Pre-process, collate and scale the data to finally make a `tf.data.Dataset`24- Prepare class weights in order to tackle major imbalances25- Create a Conv1D and Dense-based model to perform classification26- Define callbacks and hyperparameters27- Train the model28- Plot metrics from History and perform evaluation2930This example needs the following external dependencies (Gdown, Scikit-learn, Pandas,31Numpy, Matplotlib). You can install it via the following commands.3233Gdown is an external package used to download large files from Google Drive. To know34more, you can refer to its [PyPi page here](https://pypi.org/project/gdown)35"""363738"""39## Setup and Data Downloads4041First, lets install our dependencies:42"""4344"""shell45pip install gdown -q46pip install scikit-learn -q47pip install pandas -q48pip install numpy -q49pip install matplotlib -q50"""5152"""53Next, lets download our dataset.54The gdown package makes it easy to download the data from Google Drive:55"""5657"""shell58gdown 1V5B7Bt6aJm0UHbR7cRKBEK8jx7lYPVuX59# gdown will download eeg-data.csv onto the local drive for use. Total size of60# eeg-data.csv is 105.7 MB61"""6263import pandas as pd64import matplotlib.pyplot as plt65import json66import numpy as np67import keras68from keras import layers69import tensorflow as tf70from sklearn import preprocessing, model_selection71import random7273QUALITY_THRESHOLD = 12874BATCH_SIZE = 6475SHUFFLE_BUFFER_SIZE = BATCH_SIZE * 27677"""78## Read data from `eeg-data.csv`7980We use the Pandas library to read the `eeg-data.csv` file and display the first 5 rows81using the `.head()` command82"""8384eeg = pd.read_csv("eeg-data.csv")8586"""87We remove unlabeled samples from our dataset as they do not contribute to the model. We88also perform a `.drop()` operation on the columns that are not required for training data89preparation90"""9192unlabeled_eeg = eeg[eeg["label"] == "unlabeled"]93eeg = eeg.loc[eeg["label"] != "unlabeled"]94eeg = eeg.loc[eeg["label"] != "everyone paired"]9596eeg.drop(97[98"indra_time",99"Unnamed: 0",100"browser_latency",101"reading_time",102"attention_esense",103"meditation_esense",104"updatedAt",105"createdAt",106],107axis=1,108inplace=True,109)110111eeg.reset_index(drop=True, inplace=True)112eeg.head()113114"""115In the data, the samples recorded are given a score from 0 to 128 based on how116well-calibrated the sensor was (0 being best, 200 being worst). We filter the values117based on an arbitrary cutoff limit of 128.118"""119120121def convert_string_data_to_values(value_string):122str_list = json.loads(value_string)123return str_list124125126eeg["raw_values"] = eeg["raw_values"].apply(convert_string_data_to_values)127128eeg = eeg.loc[eeg["signal_quality"] < QUALITY_THRESHOLD]129eeg.head()130131"""132## Visualize one random sample from the data133"""134135"""136We visualize one sample from the data to understand how the stimulus-induced signal looks137like138"""139140141def view_eeg_plot(idx):142data = eeg.loc[idx, "raw_values"]143plt.plot(data)144plt.title(f"Sample random plot")145plt.show()146147148view_eeg_plot(7)149150"""151## Pre-process and collate data152"""153154"""155There are a total of 67 different labels present in the data, where there are numbered156sub-labels. We collate them under a single label as per their numbering and replace them157in the data itself. Following this process, we perform simple Label encoding to get them158in an integer format.159"""160161print("Before replacing labels")162print(eeg["label"].unique(), "\n")163print(len(eeg["label"].unique()), "\n")164165166eeg.replace(167{168"label": {169"blink1": "blink",170"blink2": "blink",171"blink3": "blink",172"blink4": "blink",173"blink5": "blink",174"math1": "math",175"math2": "math",176"math3": "math",177"math4": "math",178"math5": "math",179"math6": "math",180"math7": "math",181"math8": "math",182"math9": "math",183"math10": "math",184"math11": "math",185"math12": "math",186"thinkOfItems-ver1": "thinkOfItems",187"thinkOfItems-ver2": "thinkOfItems",188"video-ver1": "video",189"video-ver2": "video",190"thinkOfItemsInstruction-ver1": "thinkOfItemsInstruction",191"thinkOfItemsInstruction-ver2": "thinkOfItemsInstruction",192"colorRound1-1": "colorRound1",193"colorRound1-2": "colorRound1",194"colorRound1-3": "colorRound1",195"colorRound1-4": "colorRound1",196"colorRound1-5": "colorRound1",197"colorRound1-6": "colorRound1",198"colorRound2-1": "colorRound2",199"colorRound2-2": "colorRound2",200"colorRound2-3": "colorRound2",201"colorRound2-4": "colorRound2",202"colorRound2-5": "colorRound2",203"colorRound2-6": "colorRound2",204"colorRound3-1": "colorRound3",205"colorRound3-2": "colorRound3",206"colorRound3-3": "colorRound3",207"colorRound3-4": "colorRound3",208"colorRound3-5": "colorRound3",209"colorRound3-6": "colorRound3",210"colorRound4-1": "colorRound4",211"colorRound4-2": "colorRound4",212"colorRound4-3": "colorRound4",213"colorRound4-4": "colorRound4",214"colorRound4-5": "colorRound4",215"colorRound4-6": "colorRound4",216"colorRound5-1": "colorRound5",217"colorRound5-2": "colorRound5",218"colorRound5-3": "colorRound5",219"colorRound5-4": "colorRound5",220"colorRound5-5": "colorRound5",221"colorRound5-6": "colorRound5",222"colorInstruction1": "colorInstruction",223"colorInstruction2": "colorInstruction",224"readyRound1": "readyRound",225"readyRound2": "readyRound",226"readyRound3": "readyRound",227"readyRound4": "readyRound",228"readyRound5": "readyRound",229"colorRound1": "colorRound",230"colorRound2": "colorRound",231"colorRound3": "colorRound",232"colorRound4": "colorRound",233"colorRound5": "colorRound",234}235},236inplace=True,237)238239print("After replacing labels")240print(eeg["label"].unique())241print(len(eeg["label"].unique()))242243le = preprocessing.LabelEncoder() # Generates a look-up table244le.fit(eeg["label"])245eeg["label"] = le.transform(eeg["label"])246247"""248We extract the number of unique classes present in the data249"""250251num_classes = len(eeg["label"].unique())252print(num_classes)253254"""255We now visualize the number of samples present in each class using a Bar plot.256"""257258plt.bar(range(num_classes), eeg["label"].value_counts())259plt.title("Number of samples per class")260plt.show()261262"""263## Scale and split data264"""265266"""267We perform a simple Min-Max scaling to bring the value-range between 0 and 1. We do not268use Standard Scaling as the data does not follow a Gaussian distribution.269"""270271scaler = preprocessing.MinMaxScaler()272series_list = [273scaler.fit_transform(np.asarray(i).reshape(-1, 1)) for i in eeg["raw_values"]274]275276labels_list = [i for i in eeg["label"]]277278"""279We now create a Train-test split with a 15% holdout set. Following this, we reshape the280data to create a sequence of length 512. We also convert the labels from their current281label-encoded form to a one-hot encoding to enable use of several different282`keras.metrics` functions.283"""284285x_train, x_test, y_train, y_test = model_selection.train_test_split(286series_list, labels_list, test_size=0.15, random_state=42, shuffle=True287)288289print(290f"Length of x_train : {len(x_train)}\nLength of x_test : {len(x_test)}\nLength of y_train : {len(y_train)}\nLength of y_test : {len(y_test)}"291)292293x_train = np.asarray(x_train).astype(np.float32).reshape(-1, 512, 1)294y_train = np.asarray(y_train).astype(np.float32).reshape(-1, 1)295y_train = keras.utils.to_categorical(y_train)296297x_test = np.asarray(x_test).astype(np.float32).reshape(-1, 512, 1)298y_test = np.asarray(y_test).astype(np.float32).reshape(-1, 1)299y_test = keras.utils.to_categorical(y_test)300301"""302## Prepare `tf.data.Dataset`303"""304305"""306We now create a `tf.data.Dataset` from this data to prepare it for training. We also307shuffle and batch the data for use later.308"""309310train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))311test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))312313train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)314test_dataset = test_dataset.batch(BATCH_SIZE)315316"""317## Make Class Weights using Naive method318"""319320"""321As we can see from the plot of number of samples per class, the dataset is imbalanced.322Hence, we **calculate weights for each class** to make sure that the model is trained in323a fair manner without preference to any specific class due to greater number of samples.324325We use a naive method to calculate these weights, finding an **inverse proportion** of326each class and using that as the weight.327"""328329vals_dict = {}330for i in eeg["label"]:331if i in vals_dict.keys():332vals_dict[i] += 1333else:334vals_dict[i] = 1335total = sum(vals_dict.values())336337# Formula used - Naive method where338# weight = 1 - (no. of samples present / total no. of samples)339# So more the samples, lower the weight340341weight_dict = {k: (1 - (v / total)) for k, v in vals_dict.items()}342print(weight_dict)343344"""345## Define simple function to plot all the metrics present in a `keras.callbacks.History`346object347"""348349350def plot_history_metrics(history: keras.callbacks.History):351total_plots = len(history.history)352cols = total_plots // 2353354rows = total_plots // cols355356if total_plots % cols != 0:357rows += 1358359pos = range(1, total_plots + 1)360plt.figure(figsize=(15, 10))361for i, (key, value) in enumerate(history.history.items()):362plt.subplot(rows, cols, pos[i])363plt.plot(range(len(value)), value)364plt.title(str(key))365plt.show()366367368"""369## Define function to generate Convolutional model370"""371372373def create_model():374input_layer = keras.Input(shape=(512, 1))375376x = layers.Conv1D(377filters=32, kernel_size=3, strides=2, activation="relu", padding="same"378)(input_layer)379x = layers.BatchNormalization()(x)380381x = layers.Conv1D(382filters=64, kernel_size=3, strides=2, activation="relu", padding="same"383)(x)384x = layers.BatchNormalization()(x)385386x = layers.Conv1D(387filters=128, kernel_size=5, strides=2, activation="relu", padding="same"388)(x)389x = layers.BatchNormalization()(x)390391x = layers.Conv1D(392filters=256, kernel_size=5, strides=2, activation="relu", padding="same"393)(x)394x = layers.BatchNormalization()(x)395396x = layers.Conv1D(397filters=512, kernel_size=7, strides=2, activation="relu", padding="same"398)(x)399x = layers.BatchNormalization()(x)400401x = layers.Conv1D(402filters=1024,403kernel_size=7,404strides=2,405activation="relu",406padding="same",407)(x)408x = layers.BatchNormalization()(x)409410x = layers.Dropout(0.2)(x)411412x = layers.Flatten()(x)413414x = layers.Dense(4096, activation="relu")(x)415x = layers.Dropout(0.2)(x)416417x = layers.Dense(4182048, activation="relu", kernel_regularizer=keras.regularizers.L2()419)(x)420x = layers.Dropout(0.2)(x)421422x = layers.Dense(4231024, activation="relu", kernel_regularizer=keras.regularizers.L2()424)(x)425x = layers.Dropout(0.2)(x)426x = layers.Dense(427128, activation="relu", kernel_regularizer=keras.regularizers.L2()428)(x)429output_layer = layers.Dense(num_classes, activation="softmax")(x)430431return keras.Model(inputs=input_layer, outputs=output_layer)432433434"""435## Get Model summary436"""437438conv_model = create_model()439conv_model.summary()440441"""442## Define callbacks, optimizer, loss and metrics443"""444445"""446We set the number of epochs at 30 after performing extensive experimentation. It was seen447that this was the optimal number, after performing Early-Stopping analysis as well.448We define a Model Checkpoint callback to make sure that we only get the best model449weights.450We also define a ReduceLROnPlateau as there were several cases found during451experimentation where the loss stagnated after a certain point. On the other hand, a452direct LRScheduler was found to be too aggressive in its decay.453"""454455epochs = 30456457callbacks = [458keras.callbacks.ModelCheckpoint(459"best_model.keras", save_best_only=True, monitor="loss"460),461keras.callbacks.ReduceLROnPlateau(462monitor="val_top_k_categorical_accuracy",463factor=0.2,464patience=2,465min_lr=0.000001,466),467]468469optimizer = keras.optimizers.Adam(amsgrad=True, learning_rate=0.001)470loss = keras.losses.CategoricalCrossentropy()471472"""473## Compile model and call `model.fit()`474"""475476"""477We use the `Adam` optimizer since it is commonly considered the best choice for478preliminary training, and was found to be the best optimizer.479We use `CategoricalCrossentropy` as the loss as our labels are in a one-hot-encoded form.480481We define the `TopKCategoricalAccuracy(k=3)`, `AUC`, `Precision` and `Recall` metrics to482further aid in understanding the model better.483"""484485conv_model.compile(486optimizer=optimizer,487loss=loss,488metrics=[489keras.metrics.TopKCategoricalAccuracy(k=3),490keras.metrics.AUC(),491keras.metrics.Precision(),492keras.metrics.Recall(),493],494)495496conv_model_history = conv_model.fit(497train_dataset,498epochs=epochs,499callbacks=callbacks,500validation_data=test_dataset,501class_weight=weight_dict,502)503504"""505## Visualize model metrics during training506"""507508"""509We use the function defined above to see model metrics during training.510"""511512plot_history_metrics(conv_model_history)513514"""515## Evaluate model on test data516"""517518loss, accuracy, auc, precision, recall = conv_model.evaluate(test_dataset)519print(f"Loss : {loss}")520print(f"Top 3 Categorical Accuracy : {accuracy}")521print(f"Area under the Curve (ROC) : {auc}")522print(f"Precision : {precision}")523print(f"Recall : {recall}")524525526def view_evaluated_eeg_plots(model):527start_index = random.randint(10, len(eeg))528end_index = start_index + 11529data = eeg.loc[start_index:end_index, "raw_values"]530data_array = [scaler.fit_transform(np.asarray(i).reshape(-1, 1)) for i in data]531data_array = [np.asarray(data_array).astype(np.float32).reshape(-1, 512, 1)]532original_labels = eeg.loc[start_index:end_index, "label"]533predicted_labels = np.argmax(model.predict(data_array, verbose=0), axis=1)534original_labels = [535le.inverse_transform(np.array(label).reshape(-1))[0]536for label in original_labels537]538predicted_labels = [539le.inverse_transform(np.array(label).reshape(-1))[0]540for label in predicted_labels541]542total_plots = 12543cols = total_plots // 3544rows = total_plots // cols545if total_plots % cols != 0:546rows += 1547pos = range(1, total_plots + 1)548fig = plt.figure(figsize=(20, 10))549for i, (plot_data, og_label, pred_label) in enumerate(550zip(data, original_labels, predicted_labels)551):552plt.subplot(rows, cols, pos[i])553plt.plot(plot_data)554plt.title(f"Actual Label : {og_label}\nPredicted Label : {pred_label}")555fig.subplots_adjust(hspace=0.5)556plt.show()557558559view_evaluated_eeg_plots(conv_model)560561562