"""1Title: Traffic forecasting using graph neural networks and LSTM2Author: [Arash Khodadadi](https://www.linkedin.com/in/arash-khodadadi-08a02490/)3Date created: 2021/12/284Last modified: 2023/11/225Description: This example demonstrates how to do timeseries forecasting over graphs.6Accelerator: GPU7"""89"""10## Introduction1112This example shows how to forecast traffic condition using graph neural networks and LSTM.13Specifically, we are interested in predicting the future values of the traffic speed given14a history of the traffic speed for a collection of road segments.1516One popular method to17solve this problem is to consider each road segment's traffic speed as a separate18timeseries and predict the future values of each timeseries19using the past values of the same timeseries.2021This method, however, ignores the dependency of the traffic speed of one road segment on22the neighboring segments. To be able to take into account the complex interactions between23the traffic speed on a collection of neighboring roads, we can define the traffic network24as a graph and consider the traffic speed as a signal on this graph. In this example,25we implement a neural network architecture which can process timeseries data over a graph.26We first show how to process the data and create a27[tf.data.Dataset](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for28forecasting over graphs. Then, we implement a model which uses graph convolution and29LSTM layers to perform forecasting over a graph.3031The data processing and the model architecture are inspired by this paper:3233Yu, Bing, Haoteng Yin, and Zhanxing Zhu. "Spatio-temporal graph convolutional networks:34a deep learning framework for traffic forecasting." Proceedings of the 27th International35Joint Conference on Artificial Intelligence, 2018.36([github](https://github.com/VeritasYin/STGCN_IJCAI-18))37"""3839"""40## Setup41"""4243import os4445os.environ["KERAS_BACKEND"] = "tensorflow"4647import pandas as pd48import numpy as np49import typing50import matplotlib.pyplot as plt5152import tensorflow as tf53import keras54from keras import layers55from keras import ops5657"""58## Data preparation59"""6061"""62### Data description6364We use a real-world traffic speed dataset named `PeMSD7`. 

"""
### Loading data
"""

url = "https://github.com/VeritasYin/STGCN_IJCAI-18/raw/master/dataset/PeMSD7_Full.zip"
data_dir = keras.utils.get_file(origin=url, extract=True, archive_format="zip")
# `get_file` returns the path of the downloaded archive; the CSV files are
# extracted into the directory containing it.
data_dir = os.path.dirname(data_dir)

route_distances = pd.read_csv(
    os.path.join(data_dir, "PeMSD7_W_228.csv"), header=None
).to_numpy()
speeds_array = pd.read_csv(
    os.path.join(data_dir, "PeMSD7_V_228.csv"), header=None
).to_numpy()

print(f"route_distances shape={route_distances.shape}")
print(f"speeds_array shape={speeds_array.shape}")

"""
### Sub-sampling roads

To reduce the problem size and make the training faster, we will only
work with a sample of 26 roads out of the 228 roads in the dataset.
We have chosen the roads by starting from road 0, choosing the 5 closest
roads to it, and continuing this process until we get 26 roads. You can choose
any other subset of the roads. We chose the roads in this way to increase the likelihood
of having roads with correlated speed timeseries.
`sample_routes` contains the IDs of the selected roads.
"""

sample_routes = [
    0,
    1,
    4,
    7,
    8,
    11,
    15,
    108,
    109,
    114,
    115,
    118,
    120,
    123,
    124,
    126,
    127,
    129,
    130,
    132,
    133,
    136,
    139,
    144,
    147,
    216,
]
route_distances = route_distances[np.ix_(sample_routes, sample_routes)]
speeds_array = speeds_array[:, sample_routes]

print(f"route_distances shape={route_distances.shape}")
print(f"speeds_array shape={speeds_array.shape}")

"""
### Data visualization

Here are the timeseries of the traffic speed for two of the routes:
"""

plt.figure(figsize=(18, 6))
plt.plot(speeds_array[:, [0, -1]])
plt.legend(["route_0", "route_25"])

"""
We can also visualize the correlation between the timeseries of different routes.
"""

plt.figure(figsize=(8, 8))
plt.matshow(np.corrcoef(speeds_array.T), 0)
plt.xlabel("road number")
plt.ylabel("road number")

"""
Using this correlation heatmap, we can see that, for example, the speeds on
routes 4, 5, and 6 are highly correlated.
"""

"""
### Splitting and normalizing data

Next, we split the speed values array into train/validation/test sets,
and normalize the resulting arrays. Note that the normalization statistics
(mean and standard deviation) are computed on the training split only:
"""

train_size, val_size = 0.5, 0.2


def preprocess(data_array: np.ndarray, train_size: float, val_size: float):
    """Splits data into train/val/test sets and normalizes the data.

    Args:
        data_array: ndarray of shape `(num_time_steps, num_routes)`
        train_size: A float value between 0.0 and 1.0 that represents the proportion of the dataset
            to include in the train split.
        val_size: A float value between 0.0 and 1.0 that represents the proportion of the dataset
            to include in the validation split.

    Returns:
        `train_array`, `val_array`, `test_array`
    """

    num_time_steps = data_array.shape[0]
    num_train, num_val = (
        int(num_time_steps * train_size),
        int(num_time_steps * val_size),
    )
    train_array = data_array[:num_train]
    mean, std = train_array.mean(axis=0), train_array.std(axis=0)

    train_array = (train_array - mean) / std
    val_array = (data_array[num_train : (num_train + num_val)] - mean) / std
    test_array = (data_array[(num_train + num_val) :] - mean) / std

    return train_array, val_array, test_array


train_array, val_array, test_array = preprocess(speeds_array, train_size, val_size)

print(f"train set size: {train_array.shape}")
print(f"validation set size: {val_array.shape}")
print(f"test set size: {test_array.shape}")
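
"""
As a quick sanity check (ours, not part of the original pipeline): since each route is
z-scored with the training statistics, the training split should have zero mean and unit
standard deviation, while the validation split will deviate slightly from these values.
"""

print(f"train mean={train_array.mean():.4f}, std={train_array.std():.4f}")
print(f"val mean={val_array.mean():.4f}, std={val_array.std():.4f}")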

"""
### Creating TensorFlow Datasets

Next, we create the datasets for our forecasting problem. The forecasting problem
can be stated as follows: given a sequence of the
road speed values at times `t+1, t+2, ..., t+T`, we want to predict the future values of
the road speeds for times `t+T+1, ..., t+T+h`. So for each time `t` the inputs to our
model are `T` vectors each of size `N` and the targets are `h` vectors each of size `N`,
where `N` is the number of roads.
"""

"""
We use the Keras built-in function
`keras.utils.timeseries_dataset_from_array`.
The function `create_tf_dataset()` below takes as input a `numpy.ndarray` and returns a
`tf.data.Dataset`. In this function `input_sequence_length=T` and `forecast_horizon=h`.

The argument `multi_horizon` needs more explanation. Assume `forecast_horizon=3`.
If `multi_horizon=True`, the model will make a forecast for time steps
`t+T+1, t+T+2, t+T+3`, so the target for each sample will have shape `(3, N)`. But if
`multi_horizon=False`, the model will make a forecast only for time step `t+T+3`, and
so the target will have shape `(1, N)`.

You may notice that the input tensor in each batch has shape
`(batch_size, input_sequence_length, num_routes, 1)`. The last dimension is added to
make the model more general: at each time step, the input features for each road may
contain multiple timeseries. For instance, one might want to use temperature timeseries
in addition to historical values of the speed as input features. In this example,
however, the last dimension of the input is always 1.

We use the last 12 values of the speed in each road to forecast the speed for 3 time
steps ahead:
"""

batch_size = 64
input_sequence_length = 12
forecast_horizon = 3
multi_horizon = False


def create_tf_dataset(
    data_array: np.ndarray,
    input_sequence_length: int,
    forecast_horizon: int,
    batch_size: int = 128,
    shuffle=True,
    multi_horizon=True,
):
    """Creates tensorflow dataset from numpy array.

    This function creates a dataset where each element is a tuple `(inputs, targets)`.
    `inputs` is a Tensor
    of shape `(batch_size, input_sequence_length, num_routes, 1)` containing
    the `input_sequence_length` past values of the timeseries for each node.
    `targets` is a Tensor of shape `(batch_size, forecast_horizon, num_routes)`
    containing the `forecast_horizon`
    future values of the timeseries for each node.

    Args:
        data_array: np.ndarray with shape `(num_time_steps, num_routes)`
        input_sequence_length: Length of the input sequence (in number of timesteps).
        forecast_horizon: If `multi_horizon=True`, the target will be the values of the timeseries for 1 to
            `forecast_horizon` timesteps ahead. If `multi_horizon=False`, the target will be the value of the
            timeseries `forecast_horizon` steps ahead (only one value).
        batch_size: Number of timeseries samples in each batch.
        shuffle: Whether to shuffle output samples, or instead draw them in chronological order.
        multi_horizon: See `forecast_horizon`.

    Returns:
        A tf.data.Dataset instance.
    """

    inputs = keras.utils.timeseries_dataset_from_array(
        np.expand_dims(data_array[:-forecast_horizon], axis=-1),
        None,
        sequence_length=input_sequence_length,
        shuffle=False,
        batch_size=batch_size,
    )

    target_offset = (
        input_sequence_length
        if multi_horizon
        else input_sequence_length + forecast_horizon - 1
    )
    target_seq_length = forecast_horizon if multi_horizon else 1
    targets = keras.utils.timeseries_dataset_from_array(
        data_array[target_offset:],
        None,
        sequence_length=target_seq_length,
        shuffle=False,
        batch_size=batch_size,
    )

    dataset = tf.data.Dataset.zip((inputs, targets))
    if shuffle:
        dataset = dataset.shuffle(100)

    return dataset.prefetch(16).cache()
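
"""
To make the indexing concrete, here is a toy illustration (ours, not part of the
pipeline) on a single fake route with values `0..9`: with `input_sequence_length=3` and
`forecast_horizon=2`, the input window `[0, 1, 2]` is paired with the targets `[3, 4]`
when `multi_horizon=True`, and with the single target `[4]` when `multi_horizon=False`.
"""

toy = np.arange(10, dtype="float32").reshape(-1, 1)  # one fake route
toy_dataset = create_tf_dataset(
    toy,
    input_sequence_length=3,
    forecast_horizon=2,
    batch_size=1,
    shuffle=False,
    multi_horizon=True,
)
toy_inputs, toy_targets = next(iter(toy_dataset))
print(toy_inputs[0, :, 0, 0].numpy(), toy_targets[0, :, 0].numpy())  # [0. 1. 2.] [3. 4.]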

train_dataset, val_dataset = (
    create_tf_dataset(data_array, input_sequence_length, forecast_horizon, batch_size)
    for data_array in [train_array, val_array]
)

test_dataset = create_tf_dataset(
    test_array,
    input_sequence_length,
    forecast_horizon,
    batch_size=test_array.shape[0],
    shuffle=False,
    multi_horizon=multi_horizon,
)
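
"""
Note that `train_dataset` and `val_dataset` use the default `multi_horizon=True`, so
their targets cover all `forecast_horizon` steps, while `test_dataset` is built with
`multi_horizon=False`. As a quick check (ours), we can peek at one training batch to
confirm the shapes described above (the final batch of an epoch may be smaller than
`batch_size`):
"""

for batch_inputs, batch_targets in train_dataset.take(1):
    print(f"inputs shape: {batch_inputs.shape}")  # (64, 12, 26, 1)
    print(f"targets shape: {batch_targets.shape}")  # (64, 3, 26)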

"""
### Roads graph

As mentioned before, we assume that the road segments form a graph.
The `PeMSD7` dataset has the road segment distances. The next step
is to create the graph adjacency matrix from these distances. Following
[Yu et al., 2018](https://arxiv.org/abs/1709.04875) (equation 10), we assume there
is an edge between two nodes in the graph if the distance between the corresponding roads
is less than a threshold.
"""


def compute_adjacency_matrix(
    route_distances: np.ndarray, sigma2: float, epsilon: float
):
    """Computes the adjacency matrix from distances matrix.

    It uses the formula in https://github.com/VeritasYin/STGCN_IJCAI-18#data-preprocessing to
    compute an adjacency matrix from the distance matrix.
    The implementation follows that paper.

    Args:
        route_distances: np.ndarray of shape `(num_routes, num_routes)`. Entry `i,j` of this array is the
            distance between roads `i,j`.
        sigma2: Determines the width of the Gaussian kernel applied to the square distances matrix.
        epsilon: A threshold specifying if there is an edge between two nodes. Specifically, `A[i,j]=1`
            if `np.exp(-w2[i,j] / sigma2) >= epsilon` and `A[i,j]=0` otherwise, where `A` is the adjacency
            matrix and `w2=route_distances * route_distances`.

    Returns:
        A boolean graph adjacency matrix.
    """
    num_routes = route_distances.shape[0]
    route_distances = route_distances / 10000.0
    w2, w_mask = (
        route_distances * route_distances,
        np.ones([num_routes, num_routes]) - np.identity(num_routes),
    )
    return (np.exp(-w2 / sigma2) >= epsilon) * w_mask


"""
The function `compute_adjacency_matrix()` returns a boolean adjacency matrix
where 1 means there is an edge between two nodes. We use the following class
to store the information about the graph.
"""


class GraphInfo:
    def __init__(self, edges: typing.Tuple[list, list], num_nodes: int):
        self.edges = edges
        self.num_nodes = num_nodes


sigma2 = 0.1
epsilon = 0.5
adjacency_matrix = compute_adjacency_matrix(route_distances, sigma2, epsilon)
node_indices, neighbor_indices = np.where(adjacency_matrix == 1)
graph = GraphInfo(
    edges=(node_indices.tolist(), neighbor_indices.tolist()),
    num_nodes=adjacency_matrix.shape[0],
)
print(f"number of nodes: {graph.num_nodes}, number of edges: {len(graph.edges[0])}")
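
"""
A tiny worked example (ours) of the Gaussian-kernel thresholding: after scaling by
10000, two roads 3000 units apart give `exp(-0.3**2 / 0.1) ≈ 0.41 < 0.5`, so with
`sigma2=0.1` and `epsilon=0.5` they are not connected, whereas roads 2000 units apart
give `exp(-0.2**2 / 0.1) ≈ 0.67 >= 0.5` and are connected. The diagonal is always
zeroed out by the mask.
"""

toy_distances = np.array(
    [
        [0.0, 2000.0, 3000.0],
        [2000.0, 0.0, 9000.0],
        [3000.0, 9000.0, 0.0],
    ]
)
print(compute_adjacency_matrix(toy_distances, sigma2=0.1, epsilon=0.5))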

"""
## Network architecture

Our model for forecasting over the graph consists of a graph convolution
layer and an LSTM layer.
"""

"""
### Graph convolution layer

Our implementation of the graph convolution layer resembles the implementation
in [this Keras example](https://keras.io/examples/graph/gnn_citations/). Note that
in that example the input to the layer is a 2D tensor of shape `(num_nodes, in_feat)`,
but in our example the input to the layer is a 4D tensor of shape
`(num_nodes, batch_size, input_seq_length, in_feat)`. The graph convolution layer
performs the following steps:

- The nodes' representations are computed in `self.compute_nodes_representation()`
by multiplying the input features by `self.weight`
- The aggregated neighbors' messages are computed in `self.compute_aggregated_messages()`
by first aggregating the neighbors' representations and then multiplying the results by
`self.weight`
- The final output of the layer is computed in `self.update()` by combining the nodes'
representations and the neighbors' aggregated messages
"""


class GraphConv(layers.Layer):
    def __init__(
        self,
        in_feat,
        out_feat,
        graph_info: GraphInfo,
        aggregation_type="mean",
        combination_type="concat",
        activation: typing.Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.graph_info = graph_info
        self.aggregation_type = aggregation_type
        self.combination_type = combination_type
        self.weight = self.add_weight(
            initializer=keras.initializers.GlorotUniform(),
            shape=(in_feat, out_feat),
            dtype="float32",
            trainable=True,
        )
        self.activation = layers.Activation(activation)

    def aggregate(self, neighbour_representations):
        aggregation_func = {
            "sum": tf.math.unsorted_segment_sum,
            "mean": tf.math.unsorted_segment_mean,
            "max": tf.math.unsorted_segment_max,
        }.get(self.aggregation_type)

        if aggregation_func:
            return aggregation_func(
                neighbour_representations,
                self.graph_info.edges[0],
                num_segments=self.graph_info.num_nodes,
            )

        raise ValueError(f"Invalid aggregation type: {self.aggregation_type}")

    def compute_nodes_representation(self, features):
        """Computes each node's representation.

        The nodes' representations are obtained by multiplying the features tensor with
        `self.weight`. Note that
        `self.weight` has shape `(in_feat, out_feat)`.

        Args:
            features: Tensor of shape `(num_nodes, batch_size, input_seq_len, in_feat)`

        Returns:
            A tensor of shape `(num_nodes, batch_size, input_seq_len, out_feat)`
        """
        return ops.matmul(features, self.weight)

    def compute_aggregated_messages(self, features):
        neighbour_representations = tf.gather(features, self.graph_info.edges[1])
        aggregated_messages = self.aggregate(neighbour_representations)
        return ops.matmul(aggregated_messages, self.weight)

    def update(self, nodes_representation, aggregated_messages):
        if self.combination_type == "concat":
            h = ops.concatenate([nodes_representation, aggregated_messages], axis=-1)
        elif self.combination_type == "add":
            h = nodes_representation + aggregated_messages
        else:
            raise ValueError(f"Invalid combination type: {self.combination_type}.")
        return self.activation(h)

    def call(self, features):
        """Forward pass.

        Args:
            features: tensor of shape `(num_nodes, batch_size, input_seq_len, in_feat)`

        Returns:
            A tensor of shape `(num_nodes, batch_size, input_seq_len, out_feat)`
            (the last dimension is `2 * out_feat` when `combination_type="concat"`)
        """
        nodes_representation = self.compute_nodes_representation(features)
        aggregated_messages = self.compute_aggregated_messages(features)
        return self.update(nodes_representation, aggregated_messages)
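
"""
A quick shape check (ours) on random data. Note that with the default
`combination_type="concat"`, the last dimension of the output is `2 * out_feat`,
since the nodes' representations and the aggregated messages are concatenated.
"""

demo_conv = GraphConv(in_feat=1, out_feat=4, graph_info=graph)
demo_input = tf.random.normal((graph.num_nodes, 8, 12, 1))  # (nodes, batch, seq, feat)
print(demo_conv(demo_input).shape)  # (26, 8, 12, 8)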

"""
### LSTM plus graph convolution

By applying the graph convolution layer to the input tensor, we get another tensor
containing the nodes' representations over time (another 4D tensor). For each time
step, a node's representation is informed by the information from its neighbors.

To make good forecasts, however, we need not only the information from the neighbors
but also need to process the information over time. To this end, we can pass each
node's tensor through a recurrent layer. The `LSTMGC` layer below first applies
a graph convolution layer to the inputs and then passes the results through an
`LSTM` layer.
"""


class LSTMGC(layers.Layer):
    """Layer comprising a graph convolution layer followed by LSTM and dense layers."""

    def __init__(
        self,
        in_feat,
        out_feat,
        lstm_units: int,
        input_seq_len: int,
        output_seq_len: int,
        graph_info: GraphInfo,
        graph_conv_params: typing.Optional[dict] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # graph conv layer
        if graph_conv_params is None:
            graph_conv_params = {
                "aggregation_type": "mean",
                "combination_type": "concat",
                "activation": None,
            }
        self.graph_conv = GraphConv(in_feat, out_feat, graph_info, **graph_conv_params)

        self.lstm = layers.LSTM(lstm_units, activation="relu")
        self.dense = layers.Dense(output_seq_len)

        self.input_seq_len, self.output_seq_len = input_seq_len, output_seq_len

    def call(self, inputs):
        """Forward pass.

        Args:
            inputs: tensor of shape `(batch_size, input_seq_len, num_nodes, in_feat)`

        Returns:
            A tensor of shape `(batch_size, output_seq_len, num_nodes)`.
        """

        # convert shape to (num_nodes, batch_size, input_seq_len, in_feat)
        inputs = ops.transpose(inputs, [2, 0, 1, 3])

        gcn_out = self.graph_conv(
            inputs
        )  # gcn_out has shape: (num_nodes, batch_size, input_seq_len, out_feat)
        shape = ops.shape(gcn_out)
        num_nodes, batch_size, input_seq_len, out_feat = (
            shape[0],
            shape[1],
            shape[2],
            shape[3],
        )

        # LSTM takes only 3D tensors as input
        gcn_out = ops.reshape(
            gcn_out, (batch_size * num_nodes, input_seq_len, out_feat)
        )
        lstm_out = self.lstm(
            gcn_out
        )  # lstm_out has shape: (batch_size * num_nodes, lstm_units)

        dense_output = self.dense(
            lstm_out
        )  # dense_output has shape: (batch_size * num_nodes, output_seq_len)
        output = ops.reshape(dense_output, (num_nodes, batch_size, self.output_seq_len))
        return ops.transpose(
            output, [1, 2, 0]
        )  # returns Tensor of shape (batch_size, output_seq_len, num_nodes)


"""
## Model training
"""

in_feat = 1
batch_size = 64
epochs = 20
input_sequence_length = 12
forecast_horizon = 3
multi_horizon = False
out_feat = 10
lstm_units = 64
graph_conv_params = {
    "aggregation_type": "mean",
    "combination_type": "concat",
    "activation": None,
}

st_gcn = LSTMGC(
    in_feat,
    out_feat,
    lstm_units,
    input_sequence_length,
    forecast_horizon,
    graph,
    graph_conv_params,
)
inputs = layers.Input((input_sequence_length, graph.num_nodes, in_feat))
outputs = st_gcn(inputs)

model = keras.models.Model(inputs, outputs)
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.0002),
    loss=keras.losses.MeanSquaredError(),
)
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=[keras.callbacks.EarlyStopping(patience=10)],
)

"""
## Making forecasts on test set

Now we can use the trained model to make forecasts for the test set. Below, we
compute the MAE of the model and compare it to the MAE of naive forecasts.
The naive forecasts are the last value of the speed for each node.
"""

x_test, y = next(test_dataset.as_numpy_iterator())
y_pred = model.predict(x_test)
plt.figure(figsize=(18, 6))
plt.plot(y[:, 0, 0])
plt.plot(y_pred[:, 0, 0])
plt.legend(["actual", "forecast"])

naive_mae, model_mae = (
    np.abs(x_test[:, -1, :, 0] - y[:, 0, :]).mean(),
    np.abs(y_pred[:, 0, :] - y[:, 0, :]).mean(),
)
print(f"naive MAE: {naive_mae}, model MAE: {model_mae}")

"""
Of course, the goal here is to demonstrate the method,
not to achieve the best performance. To improve the
model's accuracy, all model hyperparameters should be tuned carefully. In addition,
several `LSTMGC` blocks can be stacked to increase the representation power
of the model, as sketched below.
"""
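
"""
A minimal sketch of that last suggestion (our assumption, not code from the paper):
the first `LSTMGC` block outputs an intermediate sequence, to which we append a
singleton feature axis so that it matches the `(batch_size, seq_len, num_nodes, in_feat)`
input expected by the next block. `hidden_seq_len` is an arbitrary choice here.
"""

hidden_seq_len = 8
block_1 = LSTMGC(
    in_feat,
    out_feat,
    lstm_units,
    input_sequence_length,
    hidden_seq_len,
    graph,
    graph_conv_params,
)
block_2 = LSTMGC(
    1,
    out_feat,
    lstm_units,
    hidden_seq_len,
    forecast_horizon,
    graph,
    graph_conv_params,
)

stacked_inputs = layers.Input((input_sequence_length, graph.num_nodes, in_feat))
x = block_1(stacked_inputs)  # shape: (batch_size, hidden_seq_len, num_nodes)
x = ops.expand_dims(x, axis=-1)  # shape: (batch_size, hidden_seq_len, num_nodes, 1)
stacked_outputs = block_2(x)  # shape: (batch_size, forecast_horizon, num_nodes)
stacked_model = keras.models.Model(stacked_inputs, stacked_outputs)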