"""1Title: Timeseries classification from scratch2Author: [hfawaz](https://github.com/hfawaz/)3Date created: 2020/07/214Last modified: 2023/11/105Description: Training a timeseries classifier from scratch on the FordA dataset from the UCR/UEA archive.6Accelerator: GPU7"""89"""10## Introduction1112This example shows how to do timeseries classification from scratch, starting from raw13CSV timeseries files on disk. We demonstrate the workflow on the FordA dataset from the14[UCR/UEA archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/).1516"""1718"""19## Setup2021"""22import keras23import numpy as np24import matplotlib.pyplot as plt2526"""27## Load the data: the FordA dataset2829### Dataset description3031The dataset we are using here is called FordA.32The data comes from the UCR archive.33The dataset contains 3601 training instances and another 1320 testing instances.34Each timeseries corresponds to a measurement of engine noise captured by a motor sensor.35For this task, the goal is to automatically detect the presence of a specific issue with36the engine. The problem is a balanced binary classification task. The full description of37this dataset can be found [here](http://www.j-wichard.de/publications/FordPaper.pdf).3839### Read the TSV data4041We will use the `FordA_TRAIN` file for training and the42`FordA_TEST` file for testing. The simplicity of this dataset43allows us to demonstrate effectively how to use ConvNets for timeseries classification.44In this file, the first column corresponds to the label.45"""464748def readucr(filename):49data = np.loadtxt(filename, delimiter="\t")50y = data[:, 0]51x = data[:, 1:]52return x, y.astype(int)535455root_url = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/"5657x_train, y_train = readucr(root_url + "FordA_TRAIN.tsv")58x_test, y_test = readucr(root_url + "FordA_TEST.tsv")5960"""61## Visualize the data6263Here we visualize one timeseries example for each class in the dataset.6465"""6667classes = np.unique(np.concatenate((y_train, y_test), axis=0))6869plt.figure()70for c in classes:71c_x_train = x_train[y_train == c]72plt.plot(c_x_train[0], label="class " + str(c))73plt.legend(loc="best")74plt.show()75plt.close()7677"""78## Standardize the data7980Our timeseries are already in a single length (500). However, their values are81usually in various ranges. This is not ideal for a neural network;82in general we should seek to make the input values normalized.83For this specific dataset, the data is already z-normalized: each timeseries sample84has a mean equal to zero and a standard deviation equal to one. This type of85normalization is very common for timeseries classification problems, see86[Bagnall et al. 
"""
Finally, in order to use `sparse_categorical_crossentropy`, we will have to count
the number of classes beforehand.
"""

num_classes = len(np.unique(y_train))

"""
Now we shuffle the training set because we will be using the `validation_split` option
later when training.
"""

idx = np.random.permutation(len(x_train))
x_train = x_train[idx]
y_train = y_train[idx]

"""
Standardize the labels to positive integers.
The expected labels will then be 0 and 1.
"""

y_train[y_train == -1] = 0
y_test[y_test == -1] = 0

"""
## Build a model

We build a Fully Convolutional Neural Network originally proposed in
[this paper](https://arxiv.org/abs/1611.06455).
The implementation is based on the TF 2 version provided
[here](https://github.com/hfawaz/dl-4-tsc/).
The following hyperparameters (kernel_size, filters, the usage of BatchNorm) were found
via random search using [KerasTuner](https://github.com/keras-team/keras-tuner).
"""


def make_model(input_shape):
    input_layer = keras.layers.Input(input_shape)

    conv1 = keras.layers.Conv1D(filters=64, kernel_size=3, padding="same")(input_layer)
    conv1 = keras.layers.BatchNormalization()(conv1)
    conv1 = keras.layers.ReLU()(conv1)

    conv2 = keras.layers.Conv1D(filters=64, kernel_size=3, padding="same")(conv1)
    conv2 = keras.layers.BatchNormalization()(conv2)
    conv2 = keras.layers.ReLU()(conv2)

    conv3 = keras.layers.Conv1D(filters=64, kernel_size=3, padding="same")(conv2)
    conv3 = keras.layers.BatchNormalization()(conv3)
    conv3 = keras.layers.ReLU()(conv3)

    gap = keras.layers.GlobalAveragePooling1D()(conv3)

    output_layer = keras.layers.Dense(num_classes, activation="softmax")(gap)

    return keras.models.Model(inputs=input_layer, outputs=output_layer)


model = make_model(input_shape=x_train.shape[1:])
keras.utils.plot_model(model, show_shapes=True)

"""
## Train the model
"""

epochs = 500
batch_size = 32

callbacks = [
    keras.callbacks.ModelCheckpoint(
        "best_model.keras", save_best_only=True, monitor="val_loss"
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=20, min_lr=0.0001
    ),
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=50, verbose=1),
]
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)
history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    callbacks=callbacks,
    validation_split=0.2,
    verbose=1,
)

"""
## Evaluate model on test data
"""

model = keras.models.load_model("best_model.keras")

test_loss, test_acc = model.evaluate(x_test, y_test)

print("Test accuracy", test_acc)
print("Test loss", test_loss)
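"""
Since the model ends in a softmax over the two classes, `model.predict` returns class
probabilities. As a quick, optional sanity check (a sketch, not part of the original
workflow), hard class predictions can be obtained with an argmax over the class axis
and compared against the test labels:
"""

# Sketch: turn softmax probabilities into hard class predictions and compare to labels.
y_pred_probs = model.predict(x_test)
y_pred = np.argmax(y_pred_probs, axis=1)
print("Fraction of correctly classified test samples:", np.mean(y_pred == y_test))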
"""
## Plot the model's training and validation accuracy
"""

metric = "sparse_categorical_accuracy"
plt.figure()
plt.plot(history.history[metric])
plt.plot(history.history["val_" + metric])
plt.title("model " + metric)
plt.ylabel(metric, fontsize="large")
plt.xlabel("epoch", fontsize="large")
plt.legend(["train", "val"], loc="best")
plt.show()
plt.close()

"""
The training accuracy reaches almost 0.95 after 100 epochs. However, the validation
accuracy shows that the network still benefits from further training: both the validation
and the training accuracy reach almost 0.97 after about 200 epochs. Beyond the 200th
epoch, if we keep training, the validation accuracy starts decreasing while the training
accuracy keeps increasing: the model starts overfitting.
"""
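"""
The same divergence is often easier to spot on the loss curves. As an optional final
check (an addition that simply reuses the plotting recipe from the accuracy plot above),
the training and validation loss recorded in `history` can be plotted as well:
"""

# Optional: plot the training and validation loss curves from the same History object.
plt.figure()
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("model loss")
plt.ylabel("loss", fontsize="large")
plt.xlabel("epoch", fontsize="large")
plt.legend(["train", "val"], loc="best")
plt.show()
plt.close()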