"""
Title: Keypoint Detection with Transfer Learning
Author: [Sayak Paul](https://twitter.com/RisingSayak), converted to Keras 3 by [Muhammad Anas Raza](https://anasrz.com)
Date created: 2021/05/02
Last modified: 2023/07/19
Description: Training a keypoint detector with data augmentation and transfer learning.
Accelerator: GPU
"""

"""
Keypoint detection consists of locating key object parts. For example, the key parts
of our faces include nose tips, eyebrows, eye corners, and so on. These parts help to
represent the underlying object in a feature-rich manner. Keypoint detection has
applications that include pose estimation, face detection, etc.

In this example, we will build a keypoint detector using the
[StanfordExtra dataset](https://github.com/benjiebob/StanfordExtra)
and transfer learning. This example requires TensorFlow 2.4 or higher,
as well as the [`imgaug`](https://imgaug.readthedocs.io/) library,
which can be installed using the following command:
"""

"""shell
pip install -q -U imgaug
"""

"""
## Data collection
"""

"""
The StanfordExtra dataset contains 12,000 images of dogs together with keypoints and
segmentation maps. It is developed from the [Stanford dogs dataset](http://vision.stanford.edu/aditya86/ImageNetDogs/).
It can be downloaded with the command below:
"""

"""shell
wget -q http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar
"""

"""
Annotations are provided as a single JSON file in the StanfordExtra dataset, and one needs
to fill out [this form](https://forms.gle/sRtbicgxsWvRtRmUA) to get access to it. The
authors explicitly instruct users not to share the JSON file, and this example respects this wish:
you should obtain the JSON file yourself.

The JSON file is expected to be locally available as `stanfordextra_v12.zip`.

After the files are downloaded, we can extract the archives.
"""

"""shell
tar xf images.tar
unzip -qq ~/stanfordextra_v12.zip
"""

"""
## Imports
"""
from keras import layers
import keras

from imgaug.augmentables.kps import KeypointsOnImage
from imgaug.augmentables.kps import Keypoint
import imgaug.augmenters as iaa

from PIL import Image
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import json
import os

"""
## Define hyperparameters
"""

IMG_SIZE = 224
BATCH_SIZE = 64
EPOCHS = 5
NUM_KEYPOINTS = 24 * 2  # 24 keypoints, each with an x and y coordinate

"""
## Load data

The authors also provide a metadata file that specifies additional information about the
keypoints, like color information, animal pose name, etc. We will load this file in a `pandas`
dataframe to extract information for visualization purposes.
"""

IMG_DIR = "Images"
JSON = "StanfordExtra_V12/StanfordExtra_v12.json"
KEYPOINT_DEF = (
    "https://github.com/benjiebob/StanfordExtra/raw/master/keypoint_definitions.csv"
)

# Load the ground-truth annotations.
with open(JSON) as infile:
    json_data = json.load(infile)

# Set up a dictionary, mapping all the ground-truth information
# with respect to the path of the image.
json_dict = {i["img_path"]: i for i in json_data}

"""
A single entry of `json_dict` looks like the following:

```
'n02085782-Japanese_spaniel/n02085782_2886.jpg':
{'img_bbox': [205, 20, 116, 201],
 'img_height': 272,
 'img_path': 'n02085782-Japanese_spaniel/n02085782_2886.jpg',
 'img_width': 350,
 'is_multiple_dogs': False,
 'joints': [[108.66666666666667, 252.0, 1],
  [147.66666666666666, 229.0, 1],
  [163.5, 208.5, 1],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [54.0, 244.0, 1],
  [77.33333333333333, 225.33333333333334, 1],
  [79.0, 196.5, 1],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [150.66666666666666, 86.66666666666667, 1],
  [88.66666666666667, 73.0, 1],
  [116.0, 106.33333333333333, 1],
  [109.0, 123.33333333333333, 1],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0],
  [0, 0, 0]],
 'seg': ...}
```
"""
"""
145
In this example, the keys we are interested in are:
146
147
* `img_path`
148
* `joints`
149
150
There are a total of 24 entries present inside `joints`. Each entry has 3 values:
151
152
* x-coordinate
153
* y-coordinate
154
* visibility flag of the keypoints (1 indicates visibility and 0 indicates non-visibility)
155
156
As we can see `joints` contain multiple `[0, 0, 0]` entries which denote that those
157
keypoints were not labeled. In this example, we will consider both non-visible as well as
158
unlabeled keypoints in order to allow mini-batch learning.
159
"""

# Load the metadata definition file and preview it.
keypoint_def = pd.read_csv(KEYPOINT_DEF)
keypoint_def.head()

# Extract the colours and labels.
colours = keypoint_def["Hex colour"].values.tolist()
colours = ["#" + colour for colour in colours]
labels = keypoint_def["Name"].values.tolist()


# Utility for reading an image and for getting its annotations.
def get_dog(name):
    data = json_dict[name]
    img_data = plt.imread(os.path.join(IMG_DIR, data["img_path"]))
    # If the image is RGBA convert it to RGB.
    if img_data.shape[-1] == 4:
        img_data = img_data.astype(np.uint8)
        img_data = Image.fromarray(img_data)
        img_data = np.array(img_data.convert("RGB"))
    data["img_data"] = img_data

    return data


"""
## Visualize data

Now, we write a utility function to visualize the images and their keypoints.
"""


# Parts of this code come from here:
# https://github.com/benjiebob/StanfordExtra/blob/master/demo.ipynb
def visualize_keypoints(images, keypoints):
    fig, axes = plt.subplots(nrows=len(images), ncols=2, figsize=(16, 12))
    [ax.axis("off") for ax in np.ravel(axes)]

    for (ax_orig, ax_all), image, current_keypoint in zip(axes, images, keypoints):
        ax_orig.imshow(image)
        ax_all.imshow(image)

        # If the keypoints were formed by `imgaug` then the coordinates need
        # to be iterated differently.
        if isinstance(current_keypoint, KeypointsOnImage):
            for idx, kp in enumerate(current_keypoint.keypoints):
                ax_all.scatter(
                    [kp.x],
                    [kp.y],
                    c=colours[idx],
                    marker="x",
                    s=50,
                    linewidths=5,
                )
        else:
            current_keypoint = np.array(current_keypoint)
            # Since the last entry is the visibility flag, we discard it.
            current_keypoint = current_keypoint[:, :2]
            for idx, (x, y) in enumerate(current_keypoint):
                ax_all.scatter([x], [y], c=colours[idx], marker="x", s=50, linewidths=5)

    plt.tight_layout(pad=2.0)
    plt.show()


# Select four samples randomly for visualization.
samples = list(json_dict.keys())
num_samples = 4
selected_samples = np.random.choice(samples, num_samples, replace=False)

images, keypoints = [], []

for sample in selected_samples:
    data = get_dog(sample)
    image = data["img_data"]
    keypoint = data["joints"]

    images.append(image)
    keypoints.append(keypoint)

visualize_keypoints(images, keypoints)
"""
243
The plots show that we have images of non-uniform sizes, which is expected in most
244
real-world scenarios. However, if we resize these images to have a uniform shape (for
245
instance (224 x 224)) their ground-truth annotations will also be affected. The same
246
applies if we apply any geometric transformation (horizontal flip, for e.g.) to an image.
247
Fortunately, `imgaug` provides utilities that can handle this issue.
248
In the next section, we will write a data generator inheriting the
249
[`keras.utils.Sequence`](https://keras.io/api/utils/python_utils/#sequence-class) class
250
that applies data augmentation on batches of data using `imgaug`.
251
"""

"""
## Prepare data generator
"""


class KeyPointsDataset(keras.utils.PyDataset):
    def __init__(self, image_keys, aug, batch_size=BATCH_SIZE, train=True, **kwargs):
        super().__init__(**kwargs)
        self.image_keys = image_keys
        self.aug = aug
        self.batch_size = batch_size
        self.train = train
        self.on_epoch_end()

    def __len__(self):
        return len(self.image_keys) // self.batch_size

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.image_keys))
        if self.train:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        image_keys_temp = [self.image_keys[k] for k in indexes]
        (images, keypoints) = self.__data_generation(image_keys_temp)

        return (images, keypoints)

    def __data_generation(self, image_keys_temp):
        batch_images = np.empty((self.batch_size, IMG_SIZE, IMG_SIZE, 3), dtype="int")
        batch_keypoints = np.empty(
            (self.batch_size, 1, 1, NUM_KEYPOINTS), dtype="float32"
        )

        for i, key in enumerate(image_keys_temp):
            data = get_dog(key)
            current_keypoint = np.array(data["joints"])[:, :2]
            kps = []

            # To apply our data augmentation pipeline, we first need to
            # form Keypoint objects with the original coordinates.
            for j in range(0, len(current_keypoint)):
                kps.append(Keypoint(x=current_keypoint[j][0], y=current_keypoint[j][1]))

            # We then project the original image and its keypoint coordinates.
            current_image = data["img_data"]
            kps_obj = KeypointsOnImage(kps, shape=current_image.shape)

            # Apply the augmentation pipeline.
            (new_image, new_kps_obj) = self.aug(image=current_image, keypoints=kps_obj)
            batch_images[i,] = new_image

            # Parse the coordinates from the new keypoint object.
            kp_temp = []
            for keypoint in new_kps_obj:
                kp_temp.append(np.nan_to_num(keypoint.x))
                kp_temp.append(np.nan_to_num(keypoint.y))

            # More on why this reshaping later.
            batch_keypoints[i,] = np.array(kp_temp).reshape(1, 1, 24 * 2)

        # Scale the coordinates to [0, 1] range.
        batch_keypoints = batch_keypoints / IMG_SIZE

        return (batch_images, batch_keypoints)


"""
To know more about how to operate with keypoints in `imgaug` check out
[this document](https://imgaug.readthedocs.io/en/latest/source/examples_keypoints.html).
"""

"""
## Define augmentation transforms
"""

train_aug = iaa.Sequential(
    [
        iaa.Resize(IMG_SIZE, interpolation="linear"),
        iaa.Fliplr(0.3),
        # `Sometimes()` applies a function randomly to the inputs with
        # a given probability (0.3, in this case).
        iaa.Sometimes(0.3, iaa.Affine(rotate=10, scale=(0.5, 0.7))),
    ]
)

test_aug = iaa.Sequential([iaa.Resize(IMG_SIZE, interpolation="linear")])

"""
## Create training and validation splits
"""

np.random.shuffle(samples)
train_keys, validation_keys = (
    samples[int(len(samples) * 0.15) :],
    samples[: int(len(samples) * 0.15)],
)

"""
## Data generator investigation
"""

train_dataset = KeyPointsDataset(
    train_keys, train_aug, workers=2, use_multiprocessing=True
)
validation_dataset = KeyPointsDataset(
    validation_keys, test_aug, train=False, workers=2, use_multiprocessing=True
)

print(f"Total batches in training set: {len(train_dataset)}")
print(f"Total batches in validation set: {len(validation_dataset)}")

sample_images, sample_keypoints = next(iter(train_dataset))
assert sample_keypoints.max() == 1.0
assert sample_keypoints.min() == 0.0

sample_keypoints = sample_keypoints[:4].reshape(-1, 24, 2) * IMG_SIZE
visualize_keypoints(sample_images[:4], sample_keypoints)
"""
375
## Model building
376
377
The [Stanford dogs dataset](http://vision.stanford.edu/aditya86/ImageNetDogs/) (on which
378
the StanfordExtra dataset is based) was built using the [ImageNet-1k dataset](http://image-net.org/).
379
So, it is likely that the models pretrained on the ImageNet-1k dataset would be useful
380
for this task. We will use a MobileNetV2 pre-trained on this dataset as a backbone to
381
extract meaningful features from the images and then pass those to a custom regression
382
head for predicting coordinates.
383
"""


def get_model():
    # Load the pre-trained weights of MobileNetV2 and freeze the weights.
    backbone = keras.applications.MobileNetV2(
        weights="imagenet",
        include_top=False,
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
    )
    backbone.trainable = False

    inputs = layers.Input((IMG_SIZE, IMG_SIZE, 3))
    x = keras.applications.mobilenet_v2.preprocess_input(inputs)
    x = backbone(x)
    x = layers.Dropout(0.3)(x)
    x = layers.SeparableConv2D(
        NUM_KEYPOINTS, kernel_size=5, strides=1, activation="relu"
    )(x)
    outputs = layers.SeparableConv2D(
        NUM_KEYPOINTS, kernel_size=3, strides=1, activation="sigmoid"
    )(x)

    return keras.Model(inputs, outputs, name="keypoint_detector")


"""
Our custom network is fully convolutional, which makes it more parameter-efficient than
the same network with fully-connected dense layers.
"""

get_model().summary()

"""
Notice the output shape of the network: `(None, 1, 1, 48)`. This is why we have reshaped
the coordinates as: `batch_keypoints[i, :] = np.array(kp_temp).reshape(1, 1, 24 * 2)`.
"""

"""
## Model compilation and training

For this example, we will train the network only for five epochs.
"""

model = get_model()
model.compile(loss="mse", optimizer=keras.optimizers.Adam(1e-4))
model.fit(train_dataset, validation_data=validation_dataset, epochs=EPOCHS)

"""
## Make predictions and visualize them
"""

sample_val_images, sample_val_keypoints = next(iter(validation_dataset))
sample_val_images = sample_val_images[:4]
sample_val_keypoints = sample_val_keypoints[:4].reshape(-1, 24, 2) * IMG_SIZE
predictions = model.predict(sample_val_images).reshape(-1, 24, 2) * IMG_SIZE

# Ground-truth
visualize_keypoints(sample_val_images, sample_val_keypoints)

# Predictions
visualize_keypoints(sample_val_images, predictions)

"""
Predictions will likely improve with more training.
"""

"""
## Going further

* Try using other augmentation transforms from `imgaug` to investigate how that changes
the results.
* Here, we transferred the features from the pre-trained network linearly, that is, we did
not [fine-tune](https://keras.io/guides/transfer_learning/) it. You are encouraged to fine-tune it on this task and see if that
improves the performance (a minimal sketch follows below). You can also try different architectures and see how they
affect the final performance.
"""