keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/timeseries/eeg_signal_classification.py
"""
Title: Electroencephalogram Signal Classification for action identification
Author: [Suvaditya Mukherjee](https://github.com/suvadityamuk)
Date created: 2022/11/03
Last modified: 2022/11/05
Description: Training a Convolutional model to classify EEG signals produced by exposure to certain stimuli.
Accelerator: GPU
"""

"""
## Introduction

The following example explores how we can make a Convolution-based Neural Network to
perform classification on Electroencephalogram signals captured when subjects were
exposed to different stimuli.
We train a model from scratch since such signal-classification models are fairly scarce
in pre-trained format.
The data we use is sourced from the UC Berkeley-Biosense Lab where the data was collected
from 15 subjects at the same time.
Our process is as follows:

- Load the [UC Berkeley-Biosense Synchronized Brainwave Dataset](https://www.kaggle.com/datasets/berkeley-biosense/synchronized-brainwave-dataset)
- Visualize random samples from the data
- Pre-process, collate and scale the data to finally make a `tf.data.Dataset`
- Prepare class weights in order to tackle major imbalances
- Create a Conv1D and Dense-based model to perform classification
- Define callbacks and hyperparameters
- Train the model
- Plot metrics from History and perform evaluation

This example needs the following external dependencies (Gdown, Scikit-learn, Pandas,
Numpy, Matplotlib). You can install them via the commands in the next section.

Gdown is an external package used to download large files from Google Drive. To learn
more, you can refer to its [PyPI page here](https://pypi.org/project/gdown).
"""


"""
## Setup and Data Downloads

First, let's install our dependencies:
"""

"""shell
pip install gdown -q
pip install scikit-learn -q
pip install pandas -q
pip install numpy -q
pip install matplotlib -q
"""

"""
Next, let's download our dataset.
The gdown package makes it easy to download the data from Google Drive:
"""

"""shell
gdown 1V5B7Bt6aJm0UHbR7cRKBEK8jx7lYPVuX
# gdown will download eeg-data.csv onto the local drive for use. Total size of
# eeg-data.csv is 105.7 MB
"""

import pandas as pd
import matplotlib.pyplot as plt
import json
import numpy as np
import keras
from keras import layers
import tensorflow as tf
from sklearn import preprocessing, model_selection
import random

QUALITY_THRESHOLD = 128
BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = BATCH_SIZE * 2

"""
## Read data from `eeg-data.csv`

We use the Pandas library to read the `eeg-data.csv` file. The first 5 rows are displayed
with `.head()` once the unneeded rows and columns have been removed.
"""

eeg = pd.read_csv("eeg-data.csv")

"""
We remove unlabeled samples (and those labeled `everyone paired`) from our dataset as
they do not contribute to the model. We also perform a `.drop()` operation on the columns
that are not required for training data preparation.
"""

unlabeled_eeg = eeg[eeg["label"] == "unlabeled"]
eeg = eeg.loc[eeg["label"] != "unlabeled"]
eeg = eeg.loc[eeg["label"] != "everyone paired"]

eeg.drop(
    [
        "indra_time",
        "Unnamed: 0",
        "browser_latency",
        "reading_time",
        "attention_esense",
        "meditation_esense",
        "updatedAt",
        "createdAt",
    ],
    axis=1,
    inplace=True,
)

eeg.reset_index(drop=True, inplace=True)
eeg.head()

"""
In the data, each recorded sample carries a `signal_quality` score based on how
well-calibrated the sensor was (0 being best, 200 being worst). We keep only the samples
whose score is below an arbitrary cutoff of 128.
"""


def convert_string_data_to_values(value_string):
    # `raw_values` is stored as a JSON-encoded string; decode it into a list of numbers
    str_list = json.loads(value_string)
    return str_list


eeg["raw_values"] = eeg["raw_values"].apply(convert_string_data_to_values)

eeg = eeg.loc[eeg["signal_quality"] < QUALITY_THRESHOLD]
eeg.head()
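
"""
To make the decoding and filtering step concrete, here is a minimal sketch on a few
made-up rows. The row contents are hypothetical and only mimic the `raw_values` and
`signal_quality` columns; plain Python lists stand in for the DataFrame.

```python
import json

QUALITY_THRESHOLD = 128  # same cutoff value as used in this example

# Hypothetical rows: `raw_values` holds a JSON-encoded string of readings
rows = [
    {"raw_values": "[10, -3, 25]", "signal_quality": 0},
    {"raw_values": "[7, 7, 7]", "signal_quality": 200},
    {"raw_values": "[1, 2, 3]", "signal_quality": 51},
]

# Decode each string into a list of numbers
for row in rows:
    row["raw_values"] = json.loads(row["raw_values"])

# Keep only the readings whose quality score is below the cutoff
kept = [row for row in rows if row["signal_quality"] < QUALITY_THRESHOLD]

print(len(kept))  # 2
print(kept[0]["raw_values"])  # [10, -3, 25]
```

The row with a score of 200 is discarded, while the two well-calibrated rows survive.
"""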

"""
## Visualize one random sample from the data
"""

"""
We visualize one sample from the data to understand what the stimulus-induced signal
looks like.
"""


def view_eeg_plot(idx):
    data = eeg.loc[idx, "raw_values"]
    plt.plot(data)
    plt.title("Sample random plot")
    plt.show()


view_eeg_plot(7)

"""
## Pre-process and collate data
"""

"""
There are a total of 67 different labels present in the data, where there are numbered
sub-labels. We collate them under a single label as per their numbering and replace them
in the data itself. Following this process, we perform simple Label encoding to get them
in an integer format.
"""

print("Before replacing labels")
print(eeg["label"].unique(), "\n")
print(len(eeg["label"].unique()), "\n")


eeg.replace(
    {
        "label": {
            "blink1": "blink",
            "blink2": "blink",
            "blink3": "blink",
            "blink4": "blink",
            "blink5": "blink",
            "math1": "math",
            "math2": "math",
            "math3": "math",
            "math4": "math",
            "math5": "math",
            "math6": "math",
            "math7": "math",
            "math8": "math",
            "math9": "math",
            "math10": "math",
            "math11": "math",
            "math12": "math",
            "thinkOfItems-ver1": "thinkOfItems",
            "thinkOfItems-ver2": "thinkOfItems",
            "video-ver1": "video",
            "video-ver2": "video",
            "thinkOfItemsInstruction-ver1": "thinkOfItemsInstruction",
            "thinkOfItemsInstruction-ver2": "thinkOfItemsInstruction",
            "colorRound1-1": "colorRound1",
            "colorRound1-2": "colorRound1",
            "colorRound1-3": "colorRound1",
            "colorRound1-4": "colorRound1",
            "colorRound1-5": "colorRound1",
            "colorRound1-6": "colorRound1",
            "colorRound2-1": "colorRound2",
            "colorRound2-2": "colorRound2",
            "colorRound2-3": "colorRound2",
            "colorRound2-4": "colorRound2",
            "colorRound2-5": "colorRound2",
            "colorRound2-6": "colorRound2",
            "colorRound3-1": "colorRound3",
            "colorRound3-2": "colorRound3",
            "colorRound3-3": "colorRound3",
            "colorRound3-4": "colorRound3",
            "colorRound3-5": "colorRound3",
            "colorRound3-6": "colorRound3",
            "colorRound4-1": "colorRound4",
            "colorRound4-2": "colorRound4",
            "colorRound4-3": "colorRound4",
            "colorRound4-4": "colorRound4",
            "colorRound4-5": "colorRound4",
            "colorRound4-6": "colorRound4",
            "colorRound5-1": "colorRound5",
            "colorRound5-2": "colorRound5",
            "colorRound5-3": "colorRound5",
            "colorRound5-4": "colorRound5",
            "colorRound5-5": "colorRound5",
            "colorRound5-6": "colorRound5",
            "colorInstruction1": "colorInstruction",
            "colorInstruction2": "colorInstruction",
            "readyRound1": "readyRound",
            "readyRound2": "readyRound",
            "readyRound3": "readyRound",
            "readyRound4": "readyRound",
            "readyRound5": "readyRound",
            "colorRound1": "colorRound",
            "colorRound2": "colorRound",
            "colorRound3": "colorRound",
            "colorRound4": "colorRound",
            "colorRound5": "colorRound",
        }
    },
    inplace=True,
)

print("After replacing labels")
print(eeg["label"].unique())
print(len(eeg["label"].unique()))

le = preprocessing.LabelEncoder()  # Generates a look-up table
le.fit(eeg["label"])
eeg["label"] = le.transform(eeg["label"])
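
"""
The look-up table built by the label encoder can be pictured with a small pure-Python
sketch. This is only an illustration on a hypothetical list of labels, not scikit-learn's
implementation: the sorted unique labels define the class ids, and the same table is
reused to map ids back to names.

```python
labels = ["math", "blink", "video", "blink", "math"]  # hypothetical labels

# Sorted unique labels define the integer class ids
classes = sorted(set(labels))
label_to_id = {name: idx for idx, name in enumerate(classes)}

# Forward mapping (like `transform`) and reverse mapping (like `inverse_transform`)
encoded = [label_to_id[name] for name in labels]
decoded = [classes[idx] for idx in encoded]

print(classes)  # ['blink', 'math', 'video']
print(encoded)  # [1, 0, 2, 0, 1]
```

Keeping the fitted encoder around matters later on, since it is what lets us turn
predicted class ids back into readable label names.
"""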

"""
We extract the number of unique classes present in the data
"""

num_classes = len(eeg["label"].unique())
print(num_classes)

"""
We now visualize the number of samples present in each class using a Bar plot.
"""

# Sort the counts by class id so that each bar lines up with its class
class_counts = eeg["label"].value_counts().sort_index()
plt.bar(class_counts.index, class_counts.values)
plt.title("Number of samples per class")
plt.show()

"""
## Scale and split data
"""

"""
We perform a simple Min-Max scaling to bring the value-range between 0 and 1. We do not
use Standard Scaling as the data does not follow a Gaussian distribution. Note that the
scaler is re-fit on every sample, so each signal is scaled independently of the others.
"""

scaler = preprocessing.MinMaxScaler()
series_list = [
    scaler.fit_transform(np.asarray(i).reshape(-1, 1)) for i in eeg["raw_values"]
]

labels_list = list(eeg["label"])
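
"""
As a worked illustration of the Min-Max formula, `(value - min) / (max - min)`, here is
a minimal pure-Python sketch applied to a single made-up signal. The helper name
`min_max_scale` is hypothetical and is not part of scikit-learn.

```python
def min_max_scale(signal):
    # Scale one signal to the [0, 1] range using its own minimum and maximum
    lo, hi = min(signal), max(signal)
    span = hi - lo
    if span == 0:
        # A constant signal has no range; map every value to 0.0
        return [0.0 for _ in signal]
    return [(value - lo) / span for value in signal]


sample = [-20, 0, 30, 80]  # hypothetical raw readings
scaled = min_max_scale(sample)
print(scaled)  # [0.0, 0.2, 0.5, 1.0]
```

Because the minimum and maximum come from the signal itself, two recordings with very
different amplitudes both end up spanning the full 0 to 1 range.
"""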

"""
We now create a Train-test split with a 15% holdout set. Following this, we reshape the
data to create a sequence of length 512. We also convert the labels from their current
label-encoded form to a one-hot encoding to enable use of several different
`keras.metrics` functions.
"""

x_train, x_test, y_train, y_test = model_selection.train_test_split(
    series_list, labels_list, test_size=0.15, random_state=42, shuffle=True
)

print(
    f"Length of x_train : {len(x_train)}\nLength of x_test : {len(x_test)}\nLength of y_train : {len(y_train)}\nLength of y_test : {len(y_test)}"
)

# Pass `num_classes` explicitly so both splits get one-hot vectors of the same width,
# even if the highest class id happens to be missing from one of them
x_train = np.asarray(x_train).astype(np.float32).reshape(-1, 512, 1)
y_train = np.asarray(y_train).astype(np.float32).reshape(-1, 1)
y_train = keras.utils.to_categorical(y_train, num_classes)

x_test = np.asarray(x_test).astype(np.float32).reshape(-1, 512, 1)
y_test = np.asarray(y_test).astype(np.float32).reshape(-1, 1)
y_test = keras.utils.to_categorical(y_test, num_classes)
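
"""
For readers unfamiliar with one-hot encoding, the conversion can be sketched in a few
lines of plain Python. The helper `one_hot` is hypothetical and only illustrates the idea
behind `keras.utils.to_categorical`.

```python
def one_hot(label, num_classes):
    # Build a vector of zeros with a single 1.0 at the position of the class id
    vector = [0.0] * num_classes
    vector[label] = 1.0
    return vector


print(one_hot(2, 4))  # [0.0, 0.0, 1.0, 0.0]
```

The vector width must equal the number of classes, which is also the size of the
softmax output of the model we define later.
"""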

"""
## Prepare `tf.data.Dataset`
"""

"""
We now create a `tf.data.Dataset` from this data to prepare it for training. We also
shuffle and batch the data for use later.
"""

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)
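
"""
The effect of the shuffle buffer size can be sketched with a simplified pure-Python model
of buffer-based shuffling. This is an assumption-laden illustration of the idea (fill a
buffer, emit a random element, refill from the stream), not TensorFlow's implementation;
the function name `buffered_shuffle` is hypothetical.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    # Yield items in a locally shuffled order using a fixed-size buffer
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and replace it with the new item
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain whatever is left in random order
    rng.shuffle(buffer)
    yield from buffer


out = list(buffered_shuffle(range(20), buffer_size=4, seed=0))
print(out)
```

With a small buffer, an element can only move a limited distance forward from its
original position, so the shuffle is local rather than global. A larger buffer gives a
more thorough shuffle at the cost of memory.
"""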

"""
## Make Class Weights using Naive method
"""

"""
As we can see from the plot of number of samples per class, the dataset is imbalanced.
Hence, we **calculate weights for each class** to make sure that the model is trained in
a fair manner without preference to any specific class due to greater number of samples.

We use a naive method to calculate these weights: each class is weighted by **one minus
its proportion** of the samples, so that more frequent classes receive lower weights.
"""

vals_dict = {}
for i in eeg["label"]:
    if i in vals_dict:
        vals_dict[i] += 1
    else:
        vals_dict[i] = 1
total = sum(vals_dict.values())

# Formula used - Naive method where
# weight = 1 - (no. of samples present / total no. of samples)
# So more the samples, lower the weight

weight_dict = {k: (1 - (v / total)) for k, v in vals_dict.items()}
print(weight_dict)
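
"""
A small worked example shows how this naive formula behaves. The label list below is
made up for illustration, and `collections.Counter` is used here as a shorthand for the
counting loop above.

```python
from collections import Counter

labels = [0, 0, 0, 0, 0, 0, 1, 1, 1, 2]  # hypothetical encoded labels

counts = Counter(labels)  # samples per class: 6, 3 and 1
total = sum(counts.values())

# weight = 1 - (samples in class / total samples)
weights = {cls: 1 - n / total for cls, n in counts.items()}
print(weights)  # approximately {0: 0.4, 1: 0.7, 2: 0.9}
```

The rarest class gets the largest weight, but note that with many classes all weights
approach 1, so this scheme corrects imbalance only mildly.
"""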

"""
## Define simple function to plot all the metrics present in a `keras.callbacks.History` object
"""


def plot_history_metrics(history: keras.callbacks.History):
    total_plots = len(history.history)
    # Use at least one column so a single metric does not cause a division by zero
    cols = max(total_plots // 2, 1)

    rows = total_plots // cols

    if total_plots % cols != 0:
        rows += 1

    pos = range(1, total_plots + 1)
    plt.figure(figsize=(15, 10))
    for i, (key, value) in enumerate(history.history.items()):
        plt.subplot(rows, cols, pos[i])
        plt.plot(range(len(value)), value)
        plt.title(str(key))
    plt.show()


"""
## Define function to generate Convolutional model
"""


def create_model():
    input_layer = keras.Input(shape=(512, 1))

    x = layers.Conv1D(
        filters=32, kernel_size=3, strides=2, activation="relu", padding="same"
    )(input_layer)
    x = layers.BatchNormalization()(x)

    x = layers.Conv1D(
        filters=64, kernel_size=3, strides=2, activation="relu", padding="same"
    )(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv1D(
        filters=128, kernel_size=5, strides=2, activation="relu", padding="same"
    )(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv1D(
        filters=256, kernel_size=5, strides=2, activation="relu", padding="same"
    )(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv1D(
        filters=512, kernel_size=7, strides=2, activation="relu", padding="same"
    )(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv1D(
        filters=1024,
        kernel_size=7,
        strides=2,
        activation="relu",
        padding="same",
    )(x)
    x = layers.BatchNormalization()(x)

    x = layers.Dropout(0.2)(x)

    x = layers.Flatten()(x)

    x = layers.Dense(4096, activation="relu")(x)
    x = layers.Dropout(0.2)(x)

    x = layers.Dense(
        2048, activation="relu", kernel_regularizer=keras.regularizers.L2()
    )(x)
    x = layers.Dropout(0.2)(x)

    x = layers.Dense(
        1024, activation="relu", kernel_regularizer=keras.regularizers.L2()
    )(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(
        128, activation="relu", kernel_regularizer=keras.regularizers.L2()
    )(x)
    output_layer = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs=input_layer, outputs=output_layer)


"""
## Get Model summary
"""

conv_model = create_model()
conv_model.summary()
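
"""
The shapes reported by the summary can be checked by hand. The sketch below assumes the
usual rule that a convolution with `padding="same"` and stride `s` produces
`ceil(length / s)` time steps, and traces the sequence length through the six `Conv1D`
layers. The helper name is hypothetical.

```python
import math


def same_padding_output_length(length, stride):
    # Output length of a strided convolution when padding is "same"
    return math.ceil(length / stride)


length = 512  # input sequence length
filters = [32, 64, 128, 256, 512, 1024]  # filters of the six Conv1D layers

for n_filters in filters:
    length = same_padding_output_length(length, stride=2)
    print(f"Conv1D({n_filters}) -> ({length}, {n_filters})")

flattened = length * filters[-1]
first_dense_params = flattened * 4096 + 4096  # weights plus biases

print(flattened)  # 8192
print(first_dense_params)  # 33558528
```

Each stride-2 layer halves the sequence, going from 512 down to 8 time steps, so the
`Flatten` layer hands 8192 features to the first `Dense` layer, which therefore accounts
for a large share of the parameters.
"""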

"""
## Define callbacks, optimizer, loss and metrics
"""

"""
We set the number of epochs to 30 after extensive experimentation, including an
Early-Stopping analysis, which suggested that this was the optimal number.
We define a `ModelCheckpoint` callback so that only the best model (here, the one with
the lowest training loss) is saved.
We also define a `ReduceLROnPlateau` callback, as there were several cases during
experimentation where the loss stagnated after a certain point. A direct learning-rate
scheduler, on the other hand, was found to be too aggressive in its decay.
"""

epochs = 30

callbacks = [
    keras.callbacks.ModelCheckpoint(
        "best_model.keras", save_best_only=True, monitor="loss"
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor="val_top_k_categorical_accuracy",
        factor=0.2,
        patience=2,
        min_lr=0.000001,
    ),
]

optimizer = keras.optimizers.Adam(amsgrad=True, learning_rate=0.001)
loss = keras.losses.CategoricalCrossentropy()
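
"""
To build intuition for the plateau-based schedule, here is a deliberately simplified
simulation of its logic for a metric that should increase. It ignores details of the real
callback such as `min_delta` and `cooldown`, and the function name and the score values
are hypothetical.

```python
def simulate_reduce_lr_on_plateau(
    scores, lr=1e-3, factor=0.2, patience=2, min_lr=1e-6
):
    # Return the learning rate in effect after each epoch
    best = float("-inf")
    wait = 0
    history = []
    for score in scores:
        if score > best:
            # The metric improved: remember it and reset the counter
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                # No improvement for `patience` epochs: shrink the learning rate
                lr = max(lr * factor, min_lr)
                wait = 0
        history.append(lr)
    return history


scores = [0.50, 0.60, 0.60, 0.59, 0.61, 0.61, 0.61]  # made-up validation scores
lr_history = simulate_reduce_lr_on_plateau(scores)
print(lr_history)
```

In this run the learning rate drops from 1e-3 to roughly 2e-4 after two stagnant epochs,
and then to roughly 4e-5 after the next plateau, never going below `min_lr`.
"""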

"""
## Compile model and call `model.fit()`
"""

"""
We use the `Adam` optimizer since it is a common default choice for preliminary training,
and it was also found to perform best in our experiments.
We use `CategoricalCrossentropy` as the loss since our labels are one-hot encoded.

We define the `TopKCategoricalAccuracy(k=3)`, `AUC`, `Precision` and `Recall` metrics to
help us understand the model's behavior better.
"""

conv_model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=[
        keras.metrics.TopKCategoricalAccuracy(k=3),
        keras.metrics.AUC(),
        keras.metrics.Precision(),
        keras.metrics.Recall(),
    ],
)

conv_model_history = conv_model.fit(
    train_dataset,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=test_dataset,
    class_weight=weight_dict,
)
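
"""
Top-k accuracy counts a prediction as correct when the true class is among the `k`
classes with the highest predicted probability. The following pure-Python sketch
illustrates the idea on made-up probabilities; it is not the Keras implementation, and
the function name is hypothetical.

```python
def top_k_accuracy(y_true, y_prob, k=3):
    # Fraction of samples whose true class is among the k most probable classes
    hits = 0
    for true_class, probs in zip(y_true, y_prob):
        ranked = sorted(range(len(probs)), key=lambda c: probs[c], reverse=True)
        hits += true_class in ranked[:k]
    return hits / len(y_true)


y_true = [0, 2, 3]  # hypothetical true class ids
y_prob = [
    [0.15, 0.50, 0.30, 0.05],  # true class 0 is ranked third
    [0.60, 0.20, 0.15, 0.05],  # true class 2 is ranked third
    [0.40, 0.30, 0.20, 0.10],  # true class 3 is ranked last
]

print(top_k_accuracy(y_true, y_prob, k=3))  # two of the three samples are hits
print(top_k_accuracy(y_true, y_prob, k=1))  # 0.0
```

With many fine-grained classes, this metric is more forgiving than plain accuracy: none
of the three samples above is a top-1 hit, yet two of them are top-3 hits.
"""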

"""
## Visualize model metrics during training
"""

"""
We use the function defined above to see model metrics during training.
"""

plot_history_metrics(conv_model_history)

"""
## Evaluate model on test data
"""

loss, accuracy, auc, precision, recall = conv_model.evaluate(test_dataset)
print(f"Loss : {loss}")
print(f"Top 3 Categorical Accuracy : {accuracy}")
print(f"Area under the Curve (ROC) : {auc}")
print(f"Precision : {precision}")
print(f"Recall : {recall}")


def view_evaluated_eeg_plots(model):
    total_plots = 12
    # Pick a window of `total_plots` consecutive rows that fits inside the data
    start_index = random.randint(0, len(eeg) - total_plots)
    end_index = start_index + total_plots
    # Use positional indexing: row labels are no longer contiguous after filtering
    data = eeg["raw_values"].iloc[start_index:end_index]
    data_array = [scaler.fit_transform(np.asarray(i).reshape(-1, 1)) for i in data]
    data_array = np.asarray(data_array).astype(np.float32).reshape(-1, 512, 1)
    original_labels = eeg["label"].iloc[start_index:end_index]
    predicted_labels = np.argmax(model.predict(data_array, verbose=0), axis=1)
    # Map the integer class ids back to their label names
    original_labels = le.inverse_transform(np.asarray(original_labels))
    predicted_labels = le.inverse_transform(predicted_labels)
    cols = total_plots // 3
    rows = total_plots // cols
    if total_plots % cols != 0:
        rows += 1
    pos = range(1, total_plots + 1)
    fig = plt.figure(figsize=(20, 10))
    for i, (plot_data, og_label, pred_label) in enumerate(
        zip(data, original_labels, predicted_labels)
    ):
        plt.subplot(rows, cols, pos[i])
        plt.plot(plot_data)
        plt.title(f"Actual Label : {og_label}\nPredicted Label : {pred_label}")
    fig.subplots_adjust(hspace=0.5)
    plt.show()


view_evaluated_eeg_plots(conv_model)