"""
Title: Electroencephalogram Signal Classification for Brain-Computer Interface
Author: [Okba Bekhelifi](https://github.com/okbalefthanded)
Date created: 2025/01/08
Last modified: 2025/01/08
Description: A Transformer-based classification model for BCI EEG signals.
Accelerator: GPU
"""

"""
# Introduction

This tutorial will explain how to build a Transformer based Neural Network to classify
Brain-Computer Interface (BCI) Electroencephalography (EEG) data recorded in a
Steady-State Visual Evoked Potentials (SSVEPs) experiment for the application of a
brain-controlled speller.

The tutorial reproduces an experiment from the SSVEPFormer study [1]
([arXiv preprint](https://arxiv.org/abs/2210.04172) /
[peer-reviewed paper](https://www.sciencedirect.com/science/article/abs/pii/S0893608023002319)).
This model was the first Transformer based model introduced for SSVEP data classification;
we will test it on the Nakanishi et al. [2] public dataset, referred to as dataset 1 in the paper.

The process follows an inter-subject classification experiment. Given data from N subjects in
the dataset, the training partition contains data from N-1 subjects and the remaining
single subject's data is used for testing. The training set does not contain any sample from
the testing subject. This way we construct a true subject-independent model. We keep the
same parameters and settings as the original paper in all processing operations, from
preprocessing to training.

The tutorial begins with a quick BCI and dataset description, then we go through the
technicalities following these sections:
- Setup, and imports.
- Dataset download and extraction.
- Data preprocessing: EEG data filtering, segmentation and visualization of raw and
filtered data, and frequency response for a well performing participant.
- Layers and model creation.
- Evaluation: a single participant data classification as an example, then the total
participants data classification.
- Visualization: we show the results of training and inference time comparisons among
the Keras 3 available backends (JAX, TensorFlow, and PyTorch) on three different GPUs.
- Conclusion: final discussion and remarks.

"""

"""
# Dataset description

## BCI and SSVEP:
A BCI offers the ability to communicate using only brain activity. This can be achieved
through exogenous stimuli that generate specific responses indicating the intent of the
subject. The responses are elicited when the user focuses their attention on the target
stimulus. We can use visual stimuli by presenting the subject with a set of options,
typically on a monitor as a grid, to select one command at a time. Each stimulus will
flicker following a fixed frequency and phase; the resulting EEG recorded at occipital
and occipito-parietal areas of the cortex (visual cortex) will have higher power at the
frequency associated with the stimulus the subject was looking at. This type of
BCI paradigm is called Steady-State Visual Evoked Potentials (SSVEPs) and became
widely used for multiple applications due to its reliability, high classification
performance, and rapidity, as 1 second of EEG is sufficient to issue a command. Other
types of brain responses exist that do not require external stimulation; however, they
are less reliable.
[Demo video](https://www.youtube.com/watch?v=VtA6jsEMIug)

This tutorial uses the 12-command (class) public SSVEP dataset [2] with the following
interface emulating a phone dialing pad.
![dataset](/img/eeg_bci_ssvepformer/eeg_ssvepformer_dataset1_interface.jpg)

The dataset was recorded with 10 participants, each faced the above 12 SSVEP stimuli (A).
The stimulation frequencies ranged from 9.25 Hz to 14.75 Hz with a 0.5 Hz step, and phases
ranged from 0 to 1.5 π with a 0.5 π step for each row (B). The EEG signal was acquired
with 8 electrodes (channels) (PO7, PO3, POz,
PO4, PO8, O1, Oz, O2) at a sampling frequency of 2048 Hz; the stored data were then
downsampled to 256 Hz. The subjects completed 15 blocks of recordings, each consisting
of 12 randomly ordered stimulations (1 for each class) of 4 seconds each. In total,
each subject conducted 180 trials.

"""

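"""
As a quick illustration of the stimulation design described above (and not used in the rest of the
tutorial), the sketch below lists the 12 stimulation frequencies and the 4 phase values; the exact
mapping of frequency/phase pairs to keypad positions is documented in the dataset repository.

```python
import numpy as np

# 12 stimulation frequencies: 9.25, 9.75, ..., 14.75 Hz (0.5 Hz step)
frequencies = np.arange(9.25, 15.0, 0.5)
# phase values: 0, 0.5*pi, pi, 1.5*pi (0.5*pi step)
phases = np.arange(0.0, 2.0, 0.5) * np.pi
print(frequencies, phases)
```
"""
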
"""
# Setup
"""

"""
## Select JAX backend

"""

import os

os.environ["KERAS_BACKEND"] = "jax"

"""
## Install dependencies

"""

"""shell
pip install -q numpy
pip install -q scipy
pip install -q matplotlib
"""

"""
# Imports

"""

# deep learning libraries
from keras import backend as K
from keras import layers
import keras

# visualization and signal processing imports
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from scipy.signal import butter, filtfilt
from scipy.io import loadmat

# setting the Keras channel format and random seed
K.set_image_data_format("channels_first")
keras.utils.set_random_seed(42)

"""
# Download and extract dataset

"""

"""
## Nakanishi et al. 2015 [DataSet Repo](https://github.com/mnakanishi/12JFPM_SSVEP)
"""

"""shell
curl -O https://sccn.ucsd.edu/download/cca_ssvep.zip
unzip cca_ssvep.zip
"""

"""
# Pre-Processing

The preprocessing steps are: first, read the EEG data for each subject; then
filter the raw data in a frequency interval where most of the useful information lies;
then select a fixed duration of signal starting from the onset of the stimulation
(due to the latency delay caused by the visual system, we add 135 milliseconds to
the stimulation onset). Lastly, all subjects' data are concatenated in a single tensor
of shape [subjects x samples x channels x trials]. The data labels are also
concatenated following the order of the trials in the experiments and will be a
matrix of shape [subjects x trials]
(here by channels we mean electrodes; we use this notation throughout the tutorial).
"""


def raw_signal(folder, fs=256, duration=1.0, onset=0.135):
    """Selecting a 1-second segment of the raw EEG signal for
    subject 1.
    """
    onset = 38 + int(onset * fs)
    end = int(duration * fs)
    data = loadmat(f"{folder}/s1.mat")
    # samples, channels, trials, targets
    eeg = data["eeg"].transpose((2, 1, 3, 0))
    # segment data
    eeg = eeg[onset : onset + end, :, :, :]
    return eeg


def segment_eeg(
    folder, elecs=None, fs=256, duration=1.0, band=[5.0, 45.0], order=4, onset=0.135
):
    """Filtering and segmenting EEG signals for all subjects."""
    n_subjects = 10
    onset = 38 + int(onset * fs)
    end = int(duration * fs)
    X, Y = [], []  # empty data and labels

    for subj in range(1, n_subjects + 1):
        data = loadmat(f"{folder}/s{subj}.mat")
        # samples, channels, trials, targets
        eeg = data["eeg"].transpose((2, 1, 3, 0))
        # filter data
        eeg = filter_eeg(eeg, fs=fs, band=band, order=order)
        # segment data
        eeg = eeg[onset : onset + end, :, :, :]
        # reshape labels
        samples, channels, blocks, targets = eeg.shape
        y = np.tile(np.arange(1, targets + 1), (blocks, 1))
        y = y.reshape((1, blocks * targets), order="F")

        X.append(eeg.reshape((samples, channels, blocks * targets), order="F"))
        Y.append(y)

    X = np.array(X, dtype=np.float32, order="F")
    Y = np.array(Y, dtype=np.float32).squeeze()

    return X, Y


def filter_eeg(data, fs=256, band=[5.0, 45.0], order=4):
    """Filter EEG signal using a zero-phase IIR filter."""
    B, A = butter(order, np.array(band) / (fs / 2), btype="bandpass")
    return filtfilt(B, A, data, axis=0)


"""
## Segment data into epochs
"""

data_folder = os.path.abspath("./cca_ssvep")
band = [8, 64]  # low-frequency / high-frequency cutoffs
order = 4  # filter order
fs = 256  # sampling frequency
duration = 1.0  # 1 second

# raw signal
X_raw = raw_signal(data_folder, fs=fs, duration=duration)
print(
    f"A single subject raw EEG (X_raw) shape: {X_raw.shape} [Samples x Channels x Blocks x Targets]"
)

# segmented signal
X, Y = segment_eeg(data_folder, band=band, order=order, fs=fs, duration=duration)
print(
    f"Full training data (X) shape: {X.shape} [Subject x Samples x Channels x Trials]"
)
print(f"data labels (Y) shape: {Y.shape} [Subject x Trials]")

samples = X.shape[1]
time = np.linspace(0.0, samples / fs, samples) * 1000

"""
## Visualize EEG signal
"""

"""
## EEG in time

Raw EEG vs Filtered EEG
The same 1-second recording for subject s1 at Oz (central electrode in the visual cortex,
back of the head) is illustrated. On the left is the raw EEG as recorded, and on the right is
the filtered EEG in the [8, 64] Hz frequency band. We see less noise and
normalized amplitude values in a natural EEG range.
"""


elec = 6  # Oz channel

x_label = "Time (ms)"
y_label = "Voltage (uV)"
# Create subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Plot data on the first subplot
ax1.plot(time, X_raw[:, elec, 0, 0], "r-")
ax1.set_xlabel(x_label)
ax1.set_ylabel(y_label)
ax1.set_title("Raw EEG: 1 second at Oz")

# Plot data on the second subplot
ax2.plot(time, X[0, :, elec, 0], "b-")
ax2.set_xlabel(x_label)
ax2.set_ylabel(y_label)
ax2.set_title("Filtered EEG between 8-64 Hz: 1 second at Oz")

# Adjust spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()

275
"""
276
## EEG frequency representation
277
278
Using the welch method, we visualize the frequency power for a well performing subject
279
for the entire 4 seconds EEG recording at Oz electrode for each stimuli. the red peaks
280
indicate the stimuli fundamental frequency and the 2nd harmonics (double the fundamental
281
frequency). we see clear peaks showing the high responses from that subject which means
282
that this subject is a good candidate for SSVEP BCI control. In many cases the peaks
283
are weak or absent, meaning that subject do not achieve the task correctly.
284
285
![eeg_frequency](/img/eeg_bci_ssvepformer/eeg_ssvepformer_frequencypowers.png)
286
"""
287
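"""
As an optional illustration (not part of the original analysis), the sketch below shows how such a
power spectrum could be computed with `scipy.signal.welch` for a single raw (unsegmented) trial of
subject 1 at the Oz electrode; the pre-rendered figure above aggregates this view over all stimuli.
"""

from scipy.signal import welch

# Welch PSD of one raw trial of subject 1 at Oz (first block, first target),
# using a 2-second window for 0.5 Hz frequency resolution
data_s1 = loadmat(f"{data_folder}/s1.mat")
eeg_s1 = data_s1["eeg"].transpose((2, 1, 3, 0))  # samples, channels, blocks, targets
freqs, psd = welch(eeg_s1[:, elec, 0, 0], fs=fs, nperseg=2 * fs)

plt.figure(figsize=(6, 3))
plt.plot(freqs, psd, "b-")
plt.xlim(0, 40)
plt.xlabel("Frequency (Hz)")
plt.ylabel("Power (uV^2/Hz)")
plt.title("Welch PSD: subject 1, single trial at Oz")
plt.tight_layout()
plt.show()
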

"""
# Create Layers and model

Create Layers in a cross-framework custom component fashion.
In the SSVEPFormer, the data is first transformed to the frequency domain through a
Fast Fourier Transform (FFT), to construct a complex spectrum representation consisting of
the concatenation of frequency and phase information in a fixed frequency band. To keep
the model in an end-to-end format, we implement the complex spectrum transformation as a
non-trainable layer.

![model](/img/eeg_bci_ssvepformer/eeg_ssvepformer_model.jpg)
Unlike the Transformer architecture, the SSVEPFormer does not contain positional encoding/embedding
layers; these are replaced by a channel combination block consisting of a Conv1D layer with
kernel size 1 and a number of filters equal to double the input channels (double the count of electrodes),
followed by LayerNorm, GELU activation and dropout.
Another difference from Transformers is the absence of multi-head attention layers with an
attention mechanism.
The model encoder contains two identical and successive blocks. Each block has two
sub-blocks: a CNN module and an MLP module. The CNN module consists of a LayerNorm, a Conv1D
with the same number of filters as the channel combination, LayerNorm, GELU, Dropout and a
residual connection. The MLP module consists of a LayerNorm, a Dense layer, GELU, Dropout
and a residual connection; the Dense layer is applied on each channel separately.
The last block of the model is an MLP head with a Flatten layer, Dropout, Dense, LayerNorm,
GELU, Dropout and a Dense layer with softmax activation.
All trainable weights are initialized from a normal distribution with 0 mean and 0.01
standard deviation, as stated in the original paper.
"""


class ComplexSpectrum(keras.layers.Layer):
    def __init__(self, nfft=512, fft_start=8, fft_end=64):
        super().__init__()
        self.nfft = nfft
        self.fft_start = fft_start
        self.fft_end = fft_end

    def call(self, x):
        samples = x.shape[-1]
        # keras.ops.rfft returns a (real, imaginary) tuple of tensors
        x = keras.ops.rfft(x, fft_length=self.nfft)
        real = x[0] / samples
        imag = x[1] / samples
        real = real[:, :, self.fft_start : self.fft_end]
        imag = imag[:, :, self.fft_start : self.fft_end]
        x = keras.ops.concatenate((real, imag), axis=-1)
        return x


class ChannelComb(keras.layers.Layer):
    def __init__(self, n_channels, drop_rate=0.5):
        super().__init__()
        self.conv = layers.Conv1D(
            2 * n_channels,
            1,
            padding="same",
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.normalization = layers.LayerNormalization()
        self.activation = layers.Activation(activation="gelu")
        self.drop = layers.Dropout(drop_rate)

    def call(self, x):
        x = self.conv(x)
        x = self.normalization(x)
        x = self.activation(x)
        x = self.drop(x)
        return x



class ConvAttention(keras.layers.Layer):
    def __init__(self, n_channels, drop_rate=0.5):
        super().__init__()
        self.norm = layers.LayerNormalization()
        self.conv = layers.Conv1D(
            2 * n_channels,
            31,
            padding="same",
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.activation = layers.Activation(activation="gelu")
        self.drop = layers.Dropout(drop_rate)

    def call(self, x):
        input = x
        x = self.norm(x)
        x = self.conv(x)
        x = self.activation(x)
        x = self.drop(x)
        x = x + input
        return x


class ChannelMLP(keras.layers.Layer):
    def __init__(self, n_features, drop_rate=0.5):
        super().__init__()
        self.norm = layers.LayerNormalization()
        self.mlp = layers.Dense(
            2 * n_features,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.activation = layers.Activation(activation="gelu")
        self.drop = layers.Dropout(drop_rate)
        self.cat = layers.Concatenate(axis=1)

    def call(self, x):
        input = x
        channels = x.shape[1]  # x shape : NCF
        x = self.norm(x)
        output_channels = []
        for i in range(channels):
            c = self.mlp(x[:, :, i])
            c = layers.Reshape([1, -1])(c)
            output_channels.append(c)
        x = self.cat(output_channels)
        x = self.activation(x)
        x = self.drop(x)
        x = x + input
        return x



class Encoder(keras.layers.Layer):
    def __init__(self, n_channels, n_features, drop_rate=0.5):
        super().__init__()
        self.attention1 = ConvAttention(n_channels, drop_rate=drop_rate)
        self.mlp1 = ChannelMLP(n_features, drop_rate=drop_rate)
        self.attention2 = ConvAttention(n_channels, drop_rate=drop_rate)
        self.mlp2 = ChannelMLP(n_features, drop_rate=drop_rate)

    def call(self, x):
        x = self.attention1(x)
        x = self.mlp1(x)
        x = self.attention2(x)
        x = self.mlp2(x)
        return x


class MlpHead(keras.layers.Layer):
    def __init__(self, n_classes, drop_rate=0.5):
        super().__init__()
        self.flatten = layers.Flatten()
        self.drop = layers.Dropout(drop_rate)
        self.linear1 = layers.Dense(
            6 * n_classes,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.norm = layers.LayerNormalization()
        self.activation = layers.Activation(activation="gelu")
        self.drop2 = layers.Dropout(drop_rate)
        self.linear2 = layers.Dense(
            n_classes,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )

    def call(self, x):
        x = self.flatten(x)
        x = self.drop(x)
        x = self.linear1(x)
        x = self.norm(x)
        x = self.activation(x)
        x = self.drop2(x)
        x = self.linear2(x)
        return x


"""
### Create a sequential model with the layers above
"""


def create_ssvepformer(
    input_shape, fs, resolution, fq_band, n_channels, n_classes, drop_rate
):
    nfft = round(fs / resolution)
    fft_start = int(fq_band[0] / resolution)
    fft_end = int(fq_band[1] / resolution) + 1
    n_features = fft_end - fft_start

    model = keras.Sequential(
        [
            keras.Input(shape=input_shape),
            ComplexSpectrum(nfft, fft_start, fft_end),
            ChannelComb(n_channels=n_channels, drop_rate=drop_rate),
            Encoder(n_channels=n_channels, n_features=n_features, drop_rate=drop_rate),
            Encoder(n_channels=n_channels, n_features=n_features, drop_rate=drop_rate),
            MlpHead(n_classes=n_classes, drop_rate=drop_rate),
            layers.Activation(activation="softmax"),
        ]
    )

    return model


"""
# Evaluation
"""

# Training settings, same as the original paper
BATCH_SIZE = 128
EPOCHS = 100
LR = 0.001  # learning rate
WD = 0.001  # weight decay
MOMENTUM = 0.9
DROP_RATE = 0.5

resolution = 0.25

"""
From the entire dataset, we select folds for each subject evaluation:
we construct a tf.data.Dataset object for the training and testing data, create the model,
and launch the training using the SGD optimizer.
"""


def concatenate_subjects(x, y, fold):
    X = np.concatenate([x[idx] for idx in fold], axis=-1)
    Y = np.concatenate([y[idx] for idx in fold], axis=-1)
    X = X.transpose((2, 1, 0))  # trials x channels x samples
    return X, Y - 1  # transform labels to values from 0...11


def evaluate_subject(
    x_train,
    y_train,
    x_val,
    y_val,
    input_shape,
    fs=256,
    resolution=0.25,
    band=[8, 64],
    channels=8,
    n_classes=12,
    drop_rate=DROP_RATE,
):

    train_dataset = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    test_dataset = (
        tf.data.Dataset.from_tensor_slices((x_val, y_val))
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    model = create_ssvepformer(
        input_shape, fs, resolution, band, channels, n_classes, drop_rate
    )
    sgd = keras.optimizers.SGD(learning_rate=LR, momentum=MOMENTUM, weight_decay=WD)

    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=sgd,
        metrics=["accuracy"],
        jit_compile=True,
    )

    history = model.fit(
        train_dataset,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=test_dataset,
        verbose=0,
    )
    loss, acc = model.evaluate(test_dataset)
    return acc * 100


"""
## Run evaluation
"""

channels = X.shape[2]
samples = X.shape[1]
input_shape = (channels, samples)
n_classes = 12

model = create_ssvepformer(
    input_shape, fs, resolution, band, channels, n_classes, DROP_RATE
)
model.summary()

"""
## Evaluation on all subjects following a leave-one-subject-out data repartition scheme
"""

accs = np.zeros(10)

for subject in range(10):
    print(f"Testing subject: {subject + 1}")

    # create train / test folds
    folds = np.delete(np.arange(10), subject)
    train_index = folds
    test_index = [subject]

    # create data split for each subject
    x_train, y_train = concatenate_subjects(X, Y, train_index)
    x_val, y_val = concatenate_subjects(X, Y, test_index)

    # train and evaluate a fold and compute the time it takes
    acc = evaluate_subject(x_train, y_train, x_val, y_val, input_shape)

    accs[subject] = acc

print(f"\nAccuracy Across Subjects: {accs.mean()} % std: {np.std(accs)}")

"""
And that's it! We see how some subjects, even with none of their data in the training set, can still
achieve almost 100% correct commands, while others show poor performance around 50%. In the original
paper using PyTorch, the average accuracy was 84.04% with a 17.37 std. We reached comparable
values, given the stochastic nature of deep learning.
"""

"""
# Visualizations

Training and inference time comparison between the different backends (JAX, TensorFlow
and PyTorch) on the three GPUs available with Colab Free/Pro/Pro+: T4, L4, A100.

"""

"""
## Training Time

![training_time](/img/eeg_bci_ssvepformer/eeg_ssvepformer_keras_training_time.png)
"""

"""
## Inference Time

![inference_time](/img/eeg_bci_ssvepformer/eeg_ssvepformer_keras_inference_time.png)
"""

"""
The JAX backend was the fastest for training and inference on all the GPUs; PyTorch was
extremely slow due to the jit compilation option being disabled, because the complex
data type produced by the FFT is not supported by the PyTorch jit compiler.
"""

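"""
For reference, running this example on the PyTorch backend mainly requires selecting that backend
before Keras is imported (`os.environ["KERAS_BACKEND"] = "torch"`) and turning off jit compilation
in the `model.compile` call of `evaluate_subject`. A minimal sketch of the adjusted compile call,
not executed in this tutorial:

```python
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=sgd,
    metrics=["accuracy"],
    jit_compile=False,  # the complex FFT output is not supported by the PyTorch jit compiler
)
```
"""
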
"""
# Acknowledgment

I thank Chris Perry [X](https://x.com/thechrisperry) @GoogleColab for supporting this
work with GPU compute.
"""

"""
# References
[1] Chen, J. et al. (2023) ‘A transformer-based deep neural network model for SSVEP
classification’, Neural Networks, 164, pp. 521–534. Available at: https://doi.org/10.1016/j.neunet.2023.04.045.

[2] Nakanishi, M. et al. (2015) ‘A Comparison Study of Canonical Correlation Analysis
Based Methods for Detecting Steady-State Visual Evoked Potentials’, PLoS One, 10(10), p.
e0140703. Available at: https://doi.org/10.1371/journal.pone.0140703.
"""