Path: blob/master/examples/timeseries/timeseries_weather_forecasting.py
3507 views
"""1Title: Timeseries forecasting for weather prediction2Authors: [Prabhanshu Attri](https://prabhanshu.com/github), [Yashika Sharma](https://github.com/yashika51), [Kristi Takach](https://github.com/ktakattack), [Falak Shah](https://github.com/falaktheoptimist)3Date created: 2020/06/234Last modified: 2023/11/225Description: This notebook demonstrates how to do timeseries forecasting using a LSTM model.6Accelerator: GPU7"""89"""10## Setup11"""1213import pandas as pd14import matplotlib.pyplot as plt15import keras1617"""18## Climate Data Time-Series1920We will be using Jena Climate dataset recorded by the21[Max Planck Institute for Biogeochemistry](https://www.bgc-jena.mpg.de/wetter/).22The dataset consists of 14 features such as temperature, pressure, humidity etc, recorded once per2310 minutes.2425**Location**: Weather Station, Max Planck Institute for Biogeochemistry26in Jena, Germany2728**Time-frame Considered**: Jan 10, 2009 - December 31, 2016293031The table below shows the column names, their value formats, and their description.3233Index| Features |Format |Description34-----|---------------|-------------------|-----------------------351 |Date Time |01.01.2009 00:10:00|Date-time reference362 |p (mbar) |996.52 |The pascal SI derived unit of pressure used to quantify internal pressure. Meteorological reports typically state atmospheric pressure in millibars.373 |T (degC) |-8.02 |Temperature in Celsius384 |Tpot (K) |265.4 |Temperature in Kelvin395 |Tdew (degC) |-8.9 |Temperature in Celsius relative to humidity. Dew Point is a measure of the absolute amount of water in the air, the DP is the temperature at which the air cannot hold all the moisture in it and water condenses.406 |rh (%) |93.3 |Relative Humidity is a measure of how saturated the air is with water vapor, the %RH determines the amount of water contained within collection objects.417 |VPmax (mbar) |3.33 |Saturation vapor pressure428 |VPact (mbar) |3.11 |Vapor pressure439 |VPdef (mbar) |0.22 |Vapor pressure deficit4410 |sh (g/kg) |1.94 |Specific humidity4511 |H2OC (mmol/mol)|3.12 |Water vapor concentration4612 |rho (g/m ** 3) |1307.75 |Airtight4713 |wv (m/s) |1.03 |Wind speed4814 |max. wv (m/s) |1.75 |Maximum wind speed4915 |wd (deg) |152.3 |Wind direction in degrees50"""5152from zipfile import ZipFile5354uri = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/jena_climate_2009_2016.csv.zip"55zip_path = keras.utils.get_file(origin=uri, fname="jena_climate_2009_2016.csv.zip")56zip_file = ZipFile(zip_path)57zip_file.extractall()58csv_path = "jena_climate_2009_2016.csv"5960df = pd.read_csv(csv_path)6162"""63## Raw Data Visualization6465To give us a sense of the data we are working with, each feature has been plotted below.66This shows the distinct pattern of each feature over the time period from 2009 to 2016.67It also shows where anomalies are present, which will be addressed during normalization.68"""6970titles = [71"Pressure",72"Temperature",73"Temperature in Kelvin",74"Temperature (dew point)",75"Relative Humidity",76"Saturation vapor pressure",77"Vapor pressure",78"Vapor pressure deficit",79"Specific humidity",80"Water vapor concentration",81"Airtight",82"Wind speed",83"Maximum wind speed",84"Wind direction in degrees",85]8687feature_keys = [88"p (mbar)",89"T (degC)",90"Tpot (K)",91"Tdew (degC)",92"rh (%)",93"VPmax (mbar)",94"VPact (mbar)",95"VPdef (mbar)",96"sh (g/kg)",97"H2OC (mmol/mol)",98"rho (g/m**3)",99"wv (m/s)",100"max. wv (m/s)",101"wd (deg)",102]103104colors = [105"blue",106"orange",107"green",108"red",109"purple",110"brown",111"pink",112"gray",113"olive",114"cyan",115]116117date_time_key = "Date Time"118119120def show_raw_visualization(data):121time_data = data[date_time_key]122fig, axes = plt.subplots(123nrows=7, ncols=2, figsize=(15, 20), dpi=80, facecolor="w", edgecolor="k"124)125for i in range(len(feature_keys)):126key = feature_keys[i]127c = colors[i % (len(colors))]128t_data = data[key]129t_data.index = time_data130t_data.head()131ax = t_data.plot(132ax=axes[i // 2, i % 2],133color=c,134title="{} - {}".format(titles[i], key),135rot=25,136)137ax.legend([titles[i]])138plt.tight_layout()139140141show_raw_visualization(df)142143144"""145## Data Preprocessing146147Here we are picking ~300,000 data points for training. Observation is recorded every14810 mins, that means 6 times per hour. We will resample one point per hour since no149drastic change is expected within 60 minutes. We do this via the `sampling_rate`150argument in `timeseries_dataset_from_array` utility.151152We are tracking data from past 720 timestamps (720/6=120 hours). This data will be153used to predict the temperature after 72 timestamps (72/6=12 hours).154155Since every feature has values with156varying ranges, we do normalization to confine feature values to a range of `[0, 1]` before157training a neural network.158We do this by subtracting the mean and dividing by the standard deviation of each feature.15916071.5 % of the data will be used to train the model, i.e. 300,693 rows. `split_fraction` can161be changed to alter this percentage.162163The model is shown data for first 5 days i.e. 720 observations, that are sampled every164hour. The temperature after 72 (12 hours * 6 observation per hour) observation will be165used as a label.166"""167168split_fraction = 0.715169train_split = int(split_fraction * int(df.shape[0]))170step = 6171172past = 720173future = 72174learning_rate = 0.001175batch_size = 256176epochs = 10177178179def normalize(data, train_split):180data_mean = data[:train_split].mean(axis=0)181data_std = data[:train_split].std(axis=0)182return (data - data_mean) / data_std183184185"""186We can see from the correlation heatmap, few parameters like Relative Humidity and187Specific Humidity are redundant. Hence we will be using select features, not all.188"""189190print(191"The selected parameters are:",192", ".join([titles[i] for i in [0, 1, 5, 7, 8, 10, 11]]),193)194selected_features = [feature_keys[i] for i in [0, 1, 5, 7, 8, 10, 11]]195features = df[selected_features]196features.index = df[date_time_key]197features.head()198199features = normalize(features.values, train_split)200features = pd.DataFrame(features)201features.head()202203train_data = features.loc[0 : train_split - 1]204val_data = features.loc[train_split:]205206"""207# Training dataset208209The training dataset labels starts from the 792nd observation (720 + 72).210"""211212start = past + future213end = start + train_split214215x_train = train_data[[i for i in range(7)]].values216y_train = features.iloc[start:end][[1]]217218sequence_length = int(past / step)219220"""221The `timeseries_dataset_from_array` function takes in a sequence of data-points gathered at222equal intervals, along with time series parameters such as length of the223sequences/windows, spacing between two sequence/windows, etc., to produce batches of224sub-timeseries inputs and targets sampled from the main timeseries.225"""226227dataset_train = keras.preprocessing.timeseries_dataset_from_array(228x_train,229y_train,230sequence_length=sequence_length,231sampling_rate=step,232batch_size=batch_size,233)234235"""236## Validation dataset237238The validation dataset must not contain the last 792 rows as we won't have label data for239those records, hence 792 must be subtracted from the end of the data.240241The validation label dataset must start from 792 after train_split, hence we must add242past + future (792) to label_start.243"""244245x_end = len(val_data) - past - future246247label_start = train_split + past + future248249x_val = val_data.iloc[:x_end][[i for i in range(7)]].values250y_val = features.iloc[label_start:][[1]]251252dataset_val = keras.preprocessing.timeseries_dataset_from_array(253x_val,254y_val,255sequence_length=sequence_length,256sampling_rate=step,257batch_size=batch_size,258)259260261for batch in dataset_train.take(1):262inputs, targets = batch263264print("Input shape:", inputs.numpy().shape)265print("Target shape:", targets.numpy().shape)266267"""268## Training269"""270271inputs = keras.layers.Input(shape=(inputs.shape[1], inputs.shape[2]))272lstm_out = keras.layers.LSTM(32)(inputs)273outputs = keras.layers.Dense(1)(lstm_out)274275model = keras.Model(inputs=inputs, outputs=outputs)276model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate), loss="mse")277model.summary()278279"""280We'll use the `ModelCheckpoint` callback to regularly save checkpoints, and281the `EarlyStopping` callback to interrupt training when the validation loss282is not longer improving.283"""284285path_checkpoint = "model_checkpoint.weights.h5"286es_callback = keras.callbacks.EarlyStopping(monitor="val_loss", min_delta=0, patience=5)287288modelckpt_callback = keras.callbacks.ModelCheckpoint(289monitor="val_loss",290filepath=path_checkpoint,291verbose=1,292save_weights_only=True,293save_best_only=True,294)295296history = model.fit(297dataset_train,298epochs=epochs,299validation_data=dataset_val,300callbacks=[es_callback, modelckpt_callback],301)302303"""304We can visualize the loss with the function below. After one point, the loss stops305decreasing.306"""307308309def visualize_loss(history, title):310loss = history.history["loss"]311val_loss = history.history["val_loss"]312epochs = range(len(loss))313plt.figure()314plt.plot(epochs, loss, "b", label="Training loss")315plt.plot(epochs, val_loss, "r", label="Validation loss")316plt.title(title)317plt.xlabel("Epochs")318plt.ylabel("Loss")319plt.legend()320plt.show()321322323visualize_loss(history, "Training and Validation Loss")324325"""326## Prediction327328The trained model above is now able to make predictions for 5 sets of values from329validation set.330"""331332333def show_plot(plot_data, delta, title):334labels = ["History", "True Future", "Model Prediction"]335marker = [".-", "rx", "go"]336time_steps = list(range(-(plot_data[0].shape[0]), 0))337if delta:338future = delta339else:340future = 0341342plt.title(title)343for i, val in enumerate(plot_data):344if i:345plt.plot(future, plot_data[i], marker[i], markersize=10, label=labels[i])346else:347plt.plot(time_steps, plot_data[i].flatten(), marker[i], label=labels[i])348plt.legend()349plt.xlim([time_steps[0], (future + 5) * 2])350plt.xlabel("Time-Step")351plt.show()352return353354355for x, y in dataset_val.take(5):356show_plot(357[x[0][:, 1].numpy(), y[0].numpy(), model.predict(x)[0]],35812,359"Single Step Prediction",360)361362363