Path: blob/master/examples/timeseries/timeseries_classification_transformer.py
"""1Title: Timeseries classification with a Transformer model2Author: [Theodoros Ntakouris](https://github.com/ntakouris)3Date created: 2021/06/254Last modified: 2021/08/055Description: This notebook demonstrates how to do timeseries classification using a Transformer model.6Accelerator: GPU7"""89"""10## Introduction1112This is the Transformer architecture from13[Attention Is All You Need](https://arxiv.org/abs/1706.03762),14applied to timeseries instead of natural language.1516This example requires TensorFlow 2.4 or higher.1718## Load the dataset1920We are going to use the same dataset and preprocessing as the21[TimeSeries Classification from Scratch](https://keras.io/examples/timeseries/timeseries_classification_from_scratch)22example.23"""2425import numpy as np26import keras27from keras import layers282930def readucr(filename):31data = np.loadtxt(filename, delimiter="\t")32y = data[:, 0]33x = data[:, 1:]34return x, y.astype(int)353637root_url = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/"3839x_train, y_train = readucr(root_url + "FordA_TRAIN.tsv")40x_test, y_test = readucr(root_url + "FordA_TEST.tsv")4142x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))43x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))4445n_classes = len(np.unique(y_train))4647idx = np.random.permutation(len(x_train))48x_train = x_train[idx]49y_train = y_train[idx]5051y_train[y_train == -1] = 052y_test[y_test == -1] = 05354"""55## Build the model5657Our model processes a tensor of shape `(batch size, sequence length, features)`,58where `sequence length` is the number of time steps and `features` is each input59timeseries.6061You can replace your classification RNN layers with this one: the62inputs are fully compatible!6364We include residual connections, layer normalization, and dropout.65The resulting layer can be stacked multiple times.6667The projection layers are implemented through `keras.layers.Conv1D`.68"""6970# This implementation applies Layer Normalization before the residual connection71# to improve training stability by producing better-behaved gradients and often72# eliminating the need for learning rate warm-up.737475def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):76# Attention and Normalization77x = layers.MultiHeadAttention(78key_dim=head_size, num_heads=num_heads, dropout=dropout79)(inputs, inputs)80x = layers.Dropout(dropout)(x)81x = layers.LayerNormalization(epsilon=1e-6)(x)82res = x + inputs8384# Feed Forward Part85x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(res)86x = layers.Dropout(dropout)(x)87x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)88x = layers.LayerNormalization(epsilon=1e-6)(x)89return x + res909192"""93The main part of our model is now complete. We can stack multiple of those94`transformer_encoder` blocks and we can also proceed to add the final95Multi-Layer Perceptron classification head. Apart from a stack of `Dense`96layers, we need to reduce the output tensor of the `TransformerEncoder` part of97our model down to a vector of features for each data point in the current98batch. A common way to achieve this is to use a pooling layer. 
"""
The main part of our model is now complete. We can stack multiple of those
`transformer_encoder` blocks and we can also proceed to add the final
Multi-Layer Perceptron classification head. Apart from a stack of `Dense`
layers, we need to reduce the output tensor of the `TransformerEncoder` part of
our model down to a vector of features for each data point in the current
batch. A common way to achieve this is to use a pooling layer. For
this example, a `GlobalAveragePooling1D` layer is sufficient.
"""


def build_model(
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0,
    mlp_dropout=0,
):
    inputs = keras.Input(shape=input_shape)
    x = inputs
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)

    x = layers.GlobalAveragePooling1D(data_format="channels_last")(x)
    for dim in mlp_units:
        x = layers.Dense(dim, activation="relu")(x)
        x = layers.Dropout(mlp_dropout)(x)
    outputs = layers.Dense(n_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)


"""
## Train and evaluate
"""

input_shape = x_train.shape[1:]

model = build_model(
    input_shape,
    head_size=256,
    num_heads=4,
    ff_dim=4,
    num_transformer_blocks=4,
    mlp_units=[128],
    mlp_dropout=0.4,
    dropout=0.25,
)

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    metrics=["sparse_categorical_accuracy"],
)
model.summary()

callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]

model.fit(
    x_train,
    y_train,
    validation_split=0.2,
    epochs=150,
    batch_size=64,
    callbacks=callbacks,
)

model.evaluate(x_test, y_test, verbose=1)

"""
## Conclusions

In about 110-120 epochs (25s each on Colab), the model reaches a training
accuracy of ~0.95, a validation accuracy of ~0.84, and a test
accuracy of ~0.85, without hyperparameter tuning. And that is for a model
with fewer than 100k parameters. Of course, parameter count and accuracy could be
improved by a hyperparameter search and a more sophisticated learning rate
schedule, or a different optimizer; one such variant is sketched below.
"""
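"""
One direction hinted at above (a sketch, not part of the original example): replace
the constant learning rate with a cosine decay schedule and switch to `AdamW`
(available in recent Keras versions). The `decay_steps` value assumes the training
setup above (about 45 steps per epoch for 150 epochs) and is illustrative only.
"""

# Hypothetical variant of the compile step above, with a decayed learning rate
# and decoupled weight decay. Re-running `model.fit(...)` afterwards would train
# with this configuration instead.
lr_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-4, decay_steps=45 * 150
)
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.AdamW(learning_rate=lr_schedule, weight_decay=1e-4),
    metrics=["sparse_categorical_accuracy"],
)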