"""
Title: Next-Frame Video Prediction with Convolutional LSTMs
Author: [Amogh Joshi](https://github.com/amogh7joshi)
Date created: 2021/06/02
Last modified: 2023/11/10
Description: How to build and train a convolutional LSTM model for next-frame video prediction.
Accelerator: GPU
"""

"""
## Introduction

The
[Convolutional LSTM](https://papers.nips.cc/paper/2015/file/07563a3fe3bbe7e3ba84431ad9d055af-Paper.pdf)
architecture brings together time series processing and computer vision by
introducing a convolutional recurrent cell in an LSTM layer. In this example, we will explore the
Convolutional LSTM model in an application to next-frame prediction, the process
of predicting what video frames come next given a series of past frames.
"""

"""
## Setup
"""

import numpy as np
import matplotlib.pyplot as plt

import keras
from keras import layers

import io
import imageio
from IPython.display import Image, display
from ipywidgets import widgets, Layout, HBox

"""
## Dataset Construction

For this example, we will be using the
[Moving MNIST](http://www.cs.toronto.edu/~nitish/unsupervised_video/)
dataset.

We will download the dataset and then construct and
preprocess training and validation sets.

For next-frame prediction, our model will use a previous frame,
which we'll call `f_n`, to predict a new frame, called `f_(n + 1)`.
To allow the model to create these predictions, we'll need to process
the data so that we have "shifted" inputs and outputs: the input
frame `x_n` is used to predict the target frame `y_n = f_(n + 1)`.
"""

# Download and load the dataset.
fpath = keras.utils.get_file(
    "moving_mnist.npy",
    "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy",
)
dataset = np.load(fpath)

# Swap the axes representing the number of frames and number of data samples.
dataset = np.swapaxes(dataset, 0, 1)
# We'll pick out 1000 of the 10000 total examples and use those.
dataset = dataset[:1000, ...]
# Add a channel dimension since the images are grayscale.
dataset = np.expand_dims(dataset, axis=-1)

# Split into train and validation sets using indexing to optimize memory.
indexes = np.arange(dataset.shape[0])
np.random.shuffle(indexes)
train_index = indexes[: int(0.9 * dataset.shape[0])]
val_index = indexes[int(0.9 * dataset.shape[0]) :]
train_dataset = dataset[train_index]
val_dataset = dataset[val_index]

# Normalize the data to the 0-1 range.
train_dataset = train_dataset / 255
val_dataset = val_dataset / 255
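
"""
As a quick sanity check (an addition to the original example): Moving MNIST ships
as an array of shape `(20, 10000, 64, 64)`, so after the axis swap, the
1000-sample slice, the channel expansion, and the 90/10 split, we expect the
following shapes and value range.
"""

# Sketch of a shape/range check on the preprocessed splits.
assert train_dataset.shape == (900, 20, 64, 64, 1)
assert val_dataset.shape == (100, 20, 64, 64, 1)
assert 0.0 <= train_dataset.min() and train_dataset.max() <= 1.0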


# We'll define a helper function to shift the frames, where
# `x` is frames 0 to n - 1, and `y` is frames 1 to n.
def create_shifted_frames(data):
    x = data[:, 0 : data.shape[1] - 1, :, :]
    y = data[:, 1 : data.shape[1], :, :]
    return x, y


# Apply the processing function to the datasets.
x_train, y_train = create_shifted_frames(train_dataset)
x_val, y_val = create_shifted_frames(val_dataset)

# Inspect the dataset.
print(f"Training Dataset Shapes: {x_train.shape}, {y_train.shape}")
print(f"Validation Dataset Shapes: {x_val.shape}, {y_val.shape}")

"""
## Data Visualization

Our data consists of sequences of frames, each of which
is used to predict the upcoming frame. Let's take a look
at some of these sequential frames.
"""

# Construct a figure on which we will visualize the images.
fig, axes = plt.subplots(4, 5, figsize=(10, 8))

# Plot each of the sequential images for one random data example.
data_choice = np.random.choice(range(len(train_dataset)), size=1)[0]
for idx, ax in enumerate(axes.flat):
    ax.imshow(np.squeeze(train_dataset[data_choice][idx]), cmap="gray")
    ax.set_title(f"Frame {idx + 1}")
    ax.axis("off")

# Print information and display the figure.
print(f"Displaying frames for example {data_choice}.")
plt.show()

"""
## Model Construction

To build a Convolutional LSTM model, we will use the
`ConvLSTM2D` layer, which will accept inputs of shape
`(batch_size, num_frames, width, height, channels)`, and return
a prediction movie of the same shape.
"""

# Construct the input layer with no definite frame size.
inp = layers.Input(shape=(None, *x_train.shape[2:]))

# We will construct 3 `ConvLSTM2D` layers with batch normalization,
# followed by a `Conv3D` layer for the spatiotemporal outputs.
x = layers.ConvLSTM2D(
    filters=64,
    kernel_size=(5, 5),
    padding="same",
    return_sequences=True,
    activation="relu",
)(inp)
x = layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(
    filters=64,
    kernel_size=(3, 3),
    padding="same",
    return_sequences=True,
    activation="relu",
)(x)
x = layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(
    filters=64,
    kernel_size=(1, 1),
    padding="same",
    return_sequences=True,
    activation="relu",
)(x)
x = layers.Conv3D(
    filters=1, kernel_size=(3, 3, 3), activation="sigmoid", padding="same"
)(x)

# Next, we will build the complete model and compile it.
model = keras.models.Model(inp, x)
model.compile(
    loss=keras.losses.binary_crossentropy,
    optimizer=keras.optimizers.Adam(),
)
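
"""
Because every `ConvLSTM2D` layer uses `padding="same"` and `return_sequences=True`,
the output movie keeps the input's frame count and spatial size. A quick way to
confirm this (a check added here, not part of the original example) is a forward
pass on a dummy batch.
"""

# One all-zeros batch with the same shape as a training sample.
dummy = np.zeros((1, *x_train.shape[1:]), dtype="float32")
print(model.predict(dummy).shape)  # expected: (1, 19, 64, 64, 1)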

"""
## Model Training

With our model and data constructed, we can now train the model.
"""

# Define some callbacks to improve training.
early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=5)

# Define modifiable training hyperparameters.
epochs = 20
batch_size = 5

# Fit the model to the training data.
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(x_val, y_val),
    callbacks=[early_stopping, reduce_lr],
)
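
"""
Training can take a while, so you may want to persist the result. The original
example doesn't save the model, but a one-line sketch using the standard Keras
saving API would be as follows (the filename is arbitrary).
"""

# Save the trained model in the native Keras format.
model.save("conv_lstm_next_frame.keras")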

"""
## Frame Prediction Visualizations

With our model now constructed and trained, we can generate
some example frame predictions based on a new video.

We'll pick a random example from the validation set and
then choose the first ten frames from it. From there, we can
allow the model to predict 10 new frames, which we can compare
to the ground-truth frames.
"""

# Select a random example from the validation dataset.
example = val_dataset[np.random.choice(range(len(val_dataset)), size=1)[0]]

# Pick the first/last ten frames from the example.
frames = example[:10, ...]
original_frames = example[10:, ...]

# Predict a new set of 10 frames.
for _ in range(10):
    # Extract the model's prediction and post-process it.
    new_prediction = model.predict(np.expand_dims(frames, axis=0))
    new_prediction = np.squeeze(new_prediction, axis=0)
    predicted_frame = np.expand_dims(new_prediction[-1, ...], axis=0)

    # Extend the set of prediction frames.
    frames = np.concatenate((frames, predicted_frame), axis=0)

# Construct a figure for the original and new frames.
fig, axes = plt.subplots(2, 10, figsize=(20, 4))

# Plot the original frames.
for idx, ax in enumerate(axes[0]):
    ax.imshow(np.squeeze(original_frames[idx]), cmap="gray")
    ax.set_title(f"Frame {idx + 11}")
    ax.axis("off")

# Plot the new frames.
new_frames = frames[10:, ...]
for idx, ax in enumerate(axes[1]):
    ax.imshow(np.squeeze(new_frames[idx]), cmap="gray")
    ax.set_title(f"Frame {idx + 11}")
    ax.axis("off")

# Display the figure.
plt.show()
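
"""
Beyond eyeballing the frames, a simple quantitative check (an addition, not part
of the original example) is the per-frame mean squared error between the
predicted frames and the held-out ground truth.
"""

# Both arrays have shape (10, 64, 64, 1); average over all pixel dimensions.
mse_per_frame = np.mean((new_frames - original_frames) ** 2, axis=(1, 2, 3))
print("Per-frame MSE:", np.round(mse_per_frame, 4))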

"""
## Predicted Videos

Finally, we'll pick a few examples from the validation set
and construct some GIFs with them to see the model's
predicted videos.

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/conv-lstm)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/conv-lstm).
"""

# Select a few random examples from the dataset.
examples = val_dataset[np.random.choice(range(len(val_dataset)), size=5)]

# Iterate over the examples and predict the frames.
predicted_videos = []
for example in examples:
    # Pick the first/last ten frames from the example.
    frames = example[:10, ...]
    original_frames = example[10:, ...]
    new_predictions = np.zeros(shape=(10, *frames[0].shape))

    # Predict a new set of 10 frames.
    for i in range(10):
        # Extract the model's prediction and post-process it.
        frames = example[: 10 + i + 1, ...]
        new_prediction = model.predict(np.expand_dims(frames, axis=0))
        new_prediction = np.squeeze(new_prediction, axis=0)
        predicted_frame = np.expand_dims(new_prediction[-1, ...], axis=0)

        # Extend the set of prediction frames.
        new_predictions[i] = predicted_frame

    # Create and save GIFs for each of the ground truth/prediction images.
    for frame_set in [original_frames, new_predictions]:
        # Convert the frames to RGB `uint8` for GIF encoding.
        current_frames = np.squeeze(frame_set)
        current_frames = current_frames[..., np.newaxis] * np.ones(3)
        current_frames = (current_frames * 255).astype(np.uint8)
        current_frames = list(current_frames)

        # Construct a GIF from the frames.
        with io.BytesIO() as gif:
            imageio.mimsave(gif, current_frames, "GIF", duration=200)
            predicted_videos.append(gif.getvalue())

# Display the videos.
print(" Truth\tPrediction")
for i in range(0, len(predicted_videos), 2):
    # Construct and display an `HBox` with the ground truth and prediction.
    box = HBox(
        [
            widgets.Image(value=predicted_videos[i]),
            widgets.Image(value=predicted_videos[i + 1]),
        ]
    )
    display(box)
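
"""
If you run this script outside a notebook, the `ipywidgets` display above won't
render. As an alternative (a sketch added here; the filenames are arbitrary),
you can write the GIF bytes straight to disk.
"""

# Save each truth/prediction pair as a pair of GIF files.
for i in range(0, len(predicted_videos), 2):
    with open(f"example_{i // 2}_truth.gif", "wb") as f:
        f.write(predicted_videos[i])
    with open(f"example_{i // 2}_prediction.gif", "wb") as f:
        f.write(predicted_videos[i + 1])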