"""1Title: Electroencephalogram Signal Classification for Brain-Computer Interface2Author: [Okba Bekhelifi](https://github.com/okbalefthanded)3Date created: 2025/01/084Last modified: 2025/01/085Description: A Transformer based classification for EEG signal for BCI.6Accelerator: GPU7"""89"""10# Introduction1112This tutorial will explain how to build a Transformer based Neural Network to classify13Brain-Computer Interface (BCI) Electroencephalograpy (EEG) data recorded in a14Steady-State Visual Evoked Potentials (SSVEPs) experiment for the application of a15brain-controlled speller.1617The tutorial reproduces an experiment from the SSVEPFormer study [1]18( [arXiv preprint](https://arxiv.org/abs/2210.04172) /19[Peer-Reviewed paper](https://www.sciencedirect.com/science/article/abs/pii/S0893608023002319) ).20This model was the first Transformer based model to be introduced for SSVEP data classification,21we will test it on the Nakanishi et al. [2] public dataset as dataset 1 from the paper.2223The process follows an inter-subject classification experiment. Given N subject data in24the dataset, the training data partition contains data from N-1 subject and the remaining25single subject data is used for testing. the training set does not contain any sample from26the testing subject. This way we construct a true subject-independent model. We keep the27same parameters and settings as the original paper in all processing operations from28preprocessing to training.293031The tutorial begins with a quick BCI and dataset description then, we go through the32technicalities following these sections:33- Setup, and imports.34- Dataset download and extraction.35- Data preprocessing: EEG data filtering, segmentation and visualization of raw and36filtered data, and frequency response for a well performing participant.37- Layers and model creation.38- Evaluation: a single participant data classification as an example then the total39participants data classification.40- Visulization: we show the results of training and inference times comparison among41the Keras 3 available backends (JAX, Tensorflow, and PyTorch) on three different GPUs.42- Conclusion: final discussion and remarks.4344"""4546"""47# Dataset description4849## BCI and SSVEP:50A BCI offers the ability to communicate using only brain activity, this can be achieved51through exogenous stimuli that generate specific responses indicating the intent of the52subject. the responses are elicited when the user focuses their attention on the target53stimulus. We can use visual stimuli by presenting the subject with a set of options54typically on a monitor as a grid to select one command at a time. Each stimulus will55flicker following a fixed frequency and phase, the resulting EEG recorded at occipital56and occipito-parietal areas of the cortex (visual cortex) will have higher power in the57associated frequency with the stimulus where the subject was looking at. This type of58BCI paradigm is called the Steady-State Visual Evoked Potentials (SSVEPs) and became59widely used for multiple application due to its reliability and high perfromance in60classification and rapidity as a 1-second of EEG is sufficient making a command. 
"""
# Setup
"""

"""
## Select JAX backend
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

"""
## Install dependencies
"""

"""shell
pip install -q numpy
pip install -q scipy
pip install -q matplotlib
"""

"""
# Imports
"""

# deep learning libraries
from keras import backend as K
from keras import layers
import keras

# visualization and signal processing imports
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from scipy.signal import butter, filtfilt
from scipy.io import loadmat

# setting the seed and the Keras channel format
K.set_image_data_format("channels_first")
keras.utils.set_random_seed(42)

"""
# Download and extract dataset
"""

"""
## Nakanishi et al. 2015 [Dataset Repo](https://github.com/mnakanishi/12JFPM_SSVEP)
"""

"""shell
curl -O https://sccn.ucsd.edu/download/cca_ssvep.zip
unzip cca_ssvep.zip
"""
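
"""
Before preprocessing, it helps to peek at what a single subject's file contains. This is
an optional check added for this tutorial (not part of the original pipeline); it assumes
the archive extracted to `./cca_ssvep`, the folder used in the next section.
"""

# stored as [targets x channels x samples x blocks]; the preprocessing functions below
# transpose this to [samples x channels x blocks x targets]
mat = loadmat("cca_ssvep/s1.mat")
print(f"Raw 'eeg' array shape for subject 1: {mat['eeg'].shape}")
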
"""
# Pre-Processing

The preprocessing steps are, first, to read the EEG data for each subject, then
to filter the raw data in a frequency interval where most of the useful information lies,
and then to select a fixed duration of signal starting from the stimulation onset
(due to the latency delay caused by the visual system, we add 135 milliseconds to the
stimulation onset). Lastly, all subjects' data are concatenated into a single tensor
of the shape: [subjects x samples x channels x trials]. The data labels are also
concatenated following the order of the trials in the experiments and will be a
matrix of the shape [subjects x trials]
(here by channels we mean electrodes, we use this notation throughout the tutorial).
"""


def raw_signal(folder, fs=256, duration=1.0, onset=0.135):
    """Select a `duration`-second segment of the raw EEG signal for subject 1."""
    # offset to the stimulation onset (38 samples) plus the visual latency delay
    onset = 38 + int(onset * fs)
    end = int(duration * fs)
    data = loadmat(f"{folder}/s1.mat")
    # samples, channels, trials, targets
    eeg = data["eeg"].transpose((2, 1, 3, 0))
    # segment data
    eeg = eeg[onset : onset + end, :, :, :]
    return eeg


def segment_eeg(
    folder, elecs=None, fs=256, duration=1.0, band=[5.0, 45.0], order=4, onset=0.135
):
    """Filter and segment EEG signals for all subjects."""
    n_subjects = 10
    onset = 38 + int(onset * fs)
    end = int(duration * fs)
    X, Y = [], []  # empty data and labels

    for subj in range(1, n_subjects + 1):
        data = loadmat(f"{folder}/s{subj}.mat")
        # samples, channels, trials, targets
        eeg = data["eeg"].transpose((2, 1, 3, 0))
        # filter data
        eeg = filter_eeg(eeg, fs=fs, band=band, order=order)
        # segment data
        eeg = eeg[onset : onset + end, :, :, :]
        # reshape labels
        samples, channels, blocks, targets = eeg.shape
        y = np.tile(np.arange(1, targets + 1), (blocks, 1))
        y = y.reshape((1, blocks * targets), order="F")

        X.append(eeg.reshape((samples, channels, blocks * targets), order="F"))
        Y.append(y)

    X = np.array(X, dtype=np.float32, order="F")
    Y = np.array(Y, dtype=np.float32).squeeze()

    return X, Y


def filter_eeg(data, fs=256, band=[5.0, 45.0], order=4):
    """Filter EEG signal using a zero-phase IIR filter."""
    B, A = butter(order, np.array(band) / (fs / 2), btype="bandpass")
    return filtfilt(B, A, data, axis=0)


"""
## Segment data into epochs
"""

data_folder = os.path.abspath("./cca_ssvep")
band = [8, 64]  # low-frequency / high-frequency cutoffs
order = 4  # filter order
fs = 256  # sampling frequency
duration = 1.0  # 1 second

# raw signal
X_raw = raw_signal(data_folder, fs=fs, duration=duration)
print(
    f"A single subject raw EEG (X_raw) shape: {X_raw.shape} [Samples x Channels x Blocks x Targets]"
)

# segmented signal
X, Y = segment_eeg(data_folder, band=band, order=order, fs=fs, duration=duration)
print(
    f"Full training data (X) shape: {X.shape} [Subject x Samples x Channels x Trials]"
)
print(f"Data labels (Y) shape: {Y.shape} [Subject x Trials]")

samples = X.shape[1]
time = np.linspace(0.0, samples / fs, samples) * 1000
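
"""
As a quick sanity check added for this tutorial (not part of the original script), we can
inspect the magnitude response of the band-pass filter defined above: it should pass the
8-64 Hz band and attenuate frequencies outside it. Note that `filtfilt` applies the
filter forward and backward, so the effective gain is the squared value.
"""

from scipy.signal import freqz

b, a = butter(order, np.array(band) / (fs / 2), btype="bandpass")
w, h = freqz(b, a, worN=2048, fs=fs)
for freq in [2, 8, 30, 64, 100]:
    idx = np.argmin(np.abs(w - freq))
    print(f"Gain at {freq:>3} Hz: {np.abs(h[idx]):.3f}")
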
"""
## Visualize EEG signal
"""

"""
## EEG in time

Raw EEG vs Filtered EEG:
the same 1-second recording for subject s1 at Oz (a central electrode over the visual
cortex, at the back of the head) is illustrated. On the left is the raw EEG as recorded,
and on the right is the same EEG filtered to the [8, 64] Hz frequency band. We see less
noise and amplitude values normalized to a natural EEG range.
"""


elec = 6  # Oz channel

x_label = "Time (ms)"
y_label = "Voltage (uV)"
# Create subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Plot data on the first subplot
ax1.plot(time, X_raw[:, elec, 0, 0], "r-")
ax1.set_xlabel(x_label)
ax1.set_ylabel(y_label)
ax1.set_title("Raw EEG: 1 second at Oz")

# Plot data on the second subplot
ax2.plot(time, X[0, :, elec, 0], "b-")
ax2.set_xlabel(x_label)
ax2.set_ylabel(y_label)
ax2.set_title("Filtered EEG between 8-64 Hz: 1 second at Oz")

# Adjust spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()

"""
## EEG frequency representation

Using the Welch method, we visualize the frequency power for a well performing subject
over the entire 4-second EEG recording at the Oz electrode for each stimulus. The red
peaks indicate the stimulus fundamental frequency and the 2nd harmonic (double the
fundamental frequency). We see clear peaks showing the high responses from that subject,
which means that this subject is a good candidate for SSVEP BCI control. In many cases
the peaks are weak or absent, meaning that the subject did not perform the task correctly.
"""
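
"""
There is no plotting code for this figure in the script. As a rough, illustrative sketch
(added for this tutorial, not the original plotting code), the spectrum of a single
4-second raw trial of subject 1 at Oz can be estimated with the Welch method as follows;
for a responsive subject, the strongest component in the 8-16 Hz range should sit near
one of the stimulation frequencies.
"""

from scipy.signal import welch

# 4-second raw segments for subject 1: samples x channels x blocks x targets
X_raw_4s = raw_signal(data_folder, fs=fs, duration=4.0)
f_welch, pxx = welch(X_raw_4s[:, elec, 0, 0], fs=fs, nperseg=512)
in_band = (f_welch >= 8) & (f_welch <= 16)
peak = f_welch[in_band][np.argmax(pxx[in_band])]
print(f"Strongest 8-16 Hz component for the first target: {peak:.2f} Hz")
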
"""
# Create Layers and model

Create layers in a cross-framework custom component fashion.
In the SSVEPFormer, the data is first transformed to the frequency domain through a
Fast Fourier Transform (FFT), to construct a complex spectrum presentation consisting of
the concatenation of frequency and phase information in a fixed frequency band. To keep
the model in an end-to-end format, we implement the complex spectrum transformation as a
non-trainable layer.

Unlike the Transformer architecture, the SSVEPFormer does not contain positional
encoding/embedding layers; these are replaced by a channel combination block made of a
Conv1D layer with a kernel size of 1 and twice the number of input channels (double the
count of electrodes) as filters, followed by LayerNorm, GELU activation and dropout.
Another difference from Transformers is the absence of multi-head attention layers with
an attention mechanism.
The model encoder contains two identical and successive blocks. Each block has two
sub-blocks: a CNN module and an MLP module. The CNN module consists of a LayerNorm, a
Conv1D with the same number of filters as the channel combination block, GELU, Dropout
and a residual connection. The MLP module consists of a LayerNorm, a Dense layer, GELU,
dropout and a residual connection; the Dense layer is applied on each channel separately.
The last block of the model is an MLP head with a Flatten layer, Dropout, Dense,
LayerNorm, GELU, Dropout and a final Dense layer with softmax activation.
All trainable weights are initialized from a normal distribution with 0 mean and 0.01
standard deviation, as stated in the original paper.
"""


class ComplexSpectrum(keras.layers.Layer):
    def __init__(self, nfft=512, fft_start=8, fft_end=64):
        super().__init__()
        self.nfft = nfft
        self.fft_start = fft_start
        self.fft_end = fft_end

    def call(self, x):
        samples = x.shape[-1]
        x = keras.ops.rfft(x, fft_length=self.nfft)
        real = x[0] / samples
        imag = x[1] / samples
        real = real[:, :, self.fft_start : self.fft_end]
        imag = imag[:, :, self.fft_start : self.fft_end]
        x = keras.ops.concatenate((real, imag), axis=-1)
        return x
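
"""
As an illustrative check added for this tutorial (not part of the original script), we can
pass a dummy batch through this layer with the same settings the full model will use later
(8-64 Hz band at a 0.25 Hz frequency resolution, i.e. `nfft=1024`, `fft_start=32`,
`fft_end=257`): each 1-second, 8-channel window is mapped to 450 features per channel,
the concatenation of 225 real and 225 imaginary coefficients.
"""

spectrum_demo = ComplexSpectrum(nfft=1024, fft_start=32, fft_end=257)
dummy_eeg = keras.ops.zeros((4, 8, 256))  # batch x channels x samples
print(f"Complex spectrum output shape: {spectrum_demo(dummy_eeg).shape}")  # (4, 8, 450)
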
class ChannelComb(keras.layers.Layer):
    def __init__(self, n_channels, drop_rate=0.5):
        super().__init__()
        self.conv = layers.Conv1D(
            2 * n_channels,
            1,
            padding="same",
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.normalization = layers.LayerNormalization()
        self.activation = layers.Activation(activation="gelu")
        self.drop = layers.Dropout(drop_rate)

    def call(self, x):
        x = self.conv(x)
        x = self.normalization(x)
        x = self.activation(x)
        x = self.drop(x)
        return x


class ConvAttention(keras.layers.Layer):
    def __init__(self, n_channels, drop_rate=0.5):
        super().__init__()
        self.norm = layers.LayerNormalization()
        self.conv = layers.Conv1D(
            2 * n_channels,
            31,
            padding="same",
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.activation = layers.Activation(activation="gelu")
        self.drop = layers.Dropout(drop_rate)

    def call(self, x):
        input = x
        x = self.norm(x)
        x = self.conv(x)
        x = self.activation(x)
        x = self.drop(x)
        x = x + input
        return x


class ChannelMLP(keras.layers.Layer):
    def __init__(self, n_features, drop_rate=0.5):
        super().__init__()
        self.norm = layers.LayerNormalization()
        self.mlp = layers.Dense(
            2 * n_features,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.activation = layers.Activation(activation="gelu")
        self.drop = layers.Dropout(drop_rate)
        self.cat = layers.Concatenate(axis=1)

    def call(self, x):
        input = x
        channels = x.shape[1]  # x shape : NCF
        x = self.norm(x)
        output_channels = []
        for i in range(channels):
            c = self.mlp(x[:, :, i])
            c = layers.Reshape([1, -1])(c)
            output_channels.append(c)
        x = self.cat(output_channels)
        x = self.activation(x)
        x = self.drop(x)
        x = x + input
        return x


class Encoder(keras.layers.Layer):
    def __init__(self, n_channels, n_features, drop_rate=0.5):
        super().__init__()
        self.attention1 = ConvAttention(n_channels, drop_rate=drop_rate)
        self.mlp1 = ChannelMLP(n_features, drop_rate=drop_rate)
        self.attention2 = ConvAttention(n_channels, drop_rate=drop_rate)
        self.mlp2 = ChannelMLP(n_features, drop_rate=drop_rate)

    def call(self, x):
        x = self.attention1(x)
        x = self.mlp1(x)
        x = self.attention2(x)
        x = self.mlp2(x)
        return x


class MlpHead(keras.layers.Layer):
    def __init__(self, n_classes, drop_rate=0.5):
        super().__init__()
        self.flatten = layers.Flatten()
        self.drop = layers.Dropout(drop_rate)
        self.linear1 = layers.Dense(
            6 * n_classes,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.norm = layers.LayerNormalization()
        self.activation = layers.Activation(activation="gelu")
        self.drop2 = layers.Dropout(drop_rate)
        self.linear2 = layers.Dense(
            n_classes,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )

    def call(self, x):
        x = self.flatten(x)
        x = self.drop(x)
        x = self.linear1(x)
        x = self.norm(x)
        x = self.activation(x)
        x = self.drop2(x)
        x = self.linear2(x)
        return x


"""
### Create a sequential model with the layers above
"""


def create_ssvepformer(
    input_shape, fs, resolution, fq_band, n_channels, n_classes, drop_rate
):
    nfft = round(fs / resolution)
    fft_start = int(fq_band[0] / resolution)
    fft_end = int(fq_band[1] / resolution) + 1
    n_features = fft_end - fft_start

    model = keras.Sequential(
        [
            keras.Input(shape=input_shape),
            ComplexSpectrum(nfft, fft_start, fft_end),
            ChannelComb(n_channels=n_channels, drop_rate=drop_rate),
            Encoder(n_channels=n_channels, n_features=n_features, drop_rate=drop_rate),
            Encoder(n_channels=n_channels, n_features=n_features, drop_rate=drop_rate),
            MlpHead(n_classes=n_classes, drop_rate=drop_rate),
            layers.Activation(activation="softmax"),
        ]
    )

    return model


"""
# Evaluation
"""

# Training settings, same as the original paper
BATCH_SIZE = 128
EPOCHS = 100
LR = 0.001  # learning rate
WD = 0.001  # weight decay
MOMENTUM = 0.9
DROP_RATE = 0.5

resolution = 0.25

"""
From the entire dataset, we select folds for each subject evaluation. We then construct
tf.data.Dataset objects for the training and testing data, create the model, and launch
the training with the SGD optimizer.
"""


def concatenate_subjects(x, y, fold):
    X = np.concatenate([x[idx] for idx in fold], axis=-1)
    Y = np.concatenate([y[idx] for idx in fold], axis=-1)
    X = X.transpose((2, 1, 0))  # trials x channels x samples
    return X, Y - 1  # transform labels to values from 0...11


def evaluate_subject(
    x_train,
    y_train,
    x_val,
    y_val,
    input_shape,
    fs=256,
    resolution=0.25,
    band=[8, 64],
    channels=8,
    n_classes=12,
    drop_rate=DROP_RATE,
):

    train_dataset = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    test_dataset = (
        tf.data.Dataset.from_tensor_slices((x_val, y_val))
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    model = create_ssvepformer(
        input_shape, fs, resolution, band, channels, n_classes, drop_rate
    )
    sgd = keras.optimizers.SGD(learning_rate=LR, momentum=MOMENTUM, weight_decay=WD)

    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=sgd,
        metrics=["accuracy"],
        jit_compile=True,
    )

    history = model.fit(
        train_dataset,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=test_dataset,
        verbose=0,
    )
    loss, acc = model.evaluate(test_dataset)
    return acc * 100


"""
## Run evaluation
"""

channels = X.shape[2]
samples = X.shape[1]
input_shape = (channels, samples)
n_classes = 12

model = create_ssvepformer(
    input_shape, fs, resolution, band, channels, n_classes, DROP_RATE
)
model.summary()

"""
## Evaluation on all subjects following a leave-one-subject-out data repartition scheme
"""

accs = np.zeros(10)

for subject in range(10):
    print(f"Testing subject: {subject + 1}")

    # create train / test folds
    folds = np.delete(np.arange(10), subject)
    train_index = folds
    test_index = [subject]

    # create data split for each subject
    x_train, y_train = concatenate_subjects(X, Y, train_index)
    x_val, y_val = concatenate_subjects(X, Y, test_index)

    # train and evaluate a fold and compute the time it takes
    acc = evaluate_subject(x_train, y_train, x_val, y_val, input_shape)

    accs[subject] = acc

print(f"\nAccuracy Across Subjects: {accs.mean()} % std: {np.std(accs)}")

"""
And that's it! We see how some subjects, despite having none of their data in the
training set, can still reach almost 100% correct commands, while others show poor
performance around 50%. In the original paper, using PyTorch, the average accuracy was
84.04% with a 17.37 std; we reach similar values, given the stochastic nature of deep
learning.
"""
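
"""
The timing comparisons below were collected separately on each backend and GPU. As a
minimal sketch of how such numbers can be obtained (added for this tutorial, not the
original benchmarking code), we can build a fresh model and time repeated `predict`
calls on a batch taken from the last test fold:
"""

import time

timing_model = create_ssvepformer(
    input_shape, fs, resolution, band, channels, n_classes, DROP_RATE
)
batch = x_val[:BATCH_SIZE]
timing_model.predict(batch, verbose=0)  # warm-up / compilation
start = time.perf_counter()
for _ in range(10):
    timing_model.predict(batch, verbose=0)
print(f"Average inference time per batch: {(time.perf_counter() - start) / 10:.4f} s")
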
"""
# Visualizations

Training and inference time comparison between the different backends (JAX, TensorFlow
and PyTorch) on the three GPUs available with Colab Free/Pro/Pro+: T4, L4, A100.
"""

"""
## Training Time
"""

"""
## Inference Time
"""

"""
The JAX backend was the fastest for both training and inference on all GPUs. PyTorch was
extremely slow because the JIT compilation option had to be disabled, since the complex
data type produced by the FFT is not supported by the PyTorch JIT compiler.
"""

"""
# Acknowledgment

I thank Chris Perry [X](https://x.com/thechrisperry) @GoogleColab for supporting this
work with GPU compute.
"""

"""
# References
[1] Chen, J. et al. (2023) ‘A transformer-based deep neural network model for SSVEP
classification’, Neural Networks, 164, pp. 521–534. Available at:
https://doi.org/10.1016/j.neunet.2023.04.045

[2] Nakanishi, M. et al. (2015) ‘A Comparison Study of Canonical Correlation Analysis
Based Methods for Detecting Steady-State Visual Evoked Potentials’, PLoS ONE, 10(10),
p. e0140703. Available at: https://doi.org/10.1371/journal.pone.0140703
"""