Path: blob/master/examples/vision/keypoint_detection.py
3507 views
"""1Title: Keypoint Detection with Transfer Learning2Author: [Sayak Paul](https://twitter.com/RisingSayak), converted to Keras 3 by [Muhammad Anas Raza](https://anasrz.com)3Date created: 2021/05/024Last modified: 2023/07/195Description: Training a keypoint detector with data augmentation and transfer learning.6Accelerator: GPU7"""89"""10Keypoint detection consists of locating key object parts. For example, the key parts11of our faces include nose tips, eyebrows, eye corners, and so on. These parts help to12represent the underlying object in a feature-rich manner. Keypoint detection has13applications that include pose estimation, face detection, etc.1415In this example, we will build a keypoint detector using the16[StanfordExtra dataset](https://github.com/benjiebob/StanfordExtra),17using transfer learning. This example requires TensorFlow 2.4 or higher,18as well as [`imgaug`](https://imgaug.readthedocs.io/) library,19which can be installed using the following command:20"""2122"""shell23pip install -q -U imgaug24"""2526"""27## Data collection28"""2930"""31The StanfordExtra dataset contains 12,000 images of dogs together with keypoints and32segmentation maps. It is developed from the [Stanford dogs dataset](http://vision.stanford.edu/aditya86/ImageNetDogs/).33It can be downloaded with the command below:34"""3536"""shell37wget -q http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar38"""3940"""41Annotations are provided as a single JSON file in the StanfordExtra dataset and one needs42to fill [this form](https://forms.gle/sRtbicgxsWvRtRmUA) to get access to it. The43authors explicitly instruct users not to share the JSON file, and this example respects this wish:44you should obtain the JSON file yourself.4546The JSON file is expected to be locally available as `stanfordextra_v12.zip`.4748After the files are downloaded, we can extract the archives.49"""5051"""shell52tar xf images.tar53unzip -qq ~/stanfordextra_v12.zip54"""5556"""57## Imports58"""59from keras import layers60import keras6162from imgaug.augmentables.kps import KeypointsOnImage63from imgaug.augmentables.kps import Keypoint64import imgaug.augmenters as iaa6566from PIL import Image67from sklearn.model_selection import train_test_split68from matplotlib import pyplot as plt69import pandas as pd70import numpy as np71import json72import os7374"""75## Define hyperparameters76"""7778IMG_SIZE = 22479BATCH_SIZE = 6480EPOCHS = 581NUM_KEYPOINTS = 24 * 2 # 24 pairs each having x and y coordinates8283"""84## Load data8586The authors also provide a metadata file that specifies additional information about the87keypoints, like color information, animal pose name, etc. We will load this file in a `pandas`88dataframe to extract information for visualization purposes.89"""9091IMG_DIR = "Images"92JSON = "StanfordExtra_V12/StanfordExtra_v12.json"93KEYPOINT_DEF = (94"https://github.com/benjiebob/StanfordExtra/raw/master/keypoint_definitions.csv"95)9697# Load the ground-truth annotations.98with open(JSON) as infile:99json_data = json.load(infile)100101# Set up a dictionary, mapping all the ground-truth information102# with respect to the path of the image.103json_dict = {i["img_path"]: i for i in json_data}104105"""106A single entry of `json_dict` looks like the following:107108```109'n02085782-Japanese_spaniel/n02085782_2886.jpg':110{'img_bbox': [205, 20, 116, 201],111'img_height': 272,112'img_path': 'n02085782-Japanese_spaniel/n02085782_2886.jpg',113'img_width': 350,114'is_multiple_dogs': False,115'joints': [[108.66666666666667, 252.0, 1],116[147.66666666666666, 229.0, 1],117[163.5, 208.5, 1],118[0, 0, 0],119[0, 0, 0],120[0, 0, 0],121[54.0, 244.0, 1],122[77.33333333333333, 225.33333333333334, 1],123[79.0, 196.5, 1],124[0, 0, 0],125[0, 0, 0],126[0, 0, 0],127[0, 0, 0],128[0, 0, 0],129[150.66666666666666, 86.66666666666667, 1],130[88.66666666666667, 73.0, 1],131[116.0, 106.33333333333333, 1],132[109.0, 123.33333333333333, 1],133[0, 0, 0],134[0, 0, 0],135[0, 0, 0],136[0, 0, 0],137[0, 0, 0],138[0, 0, 0]],139'seg': ...}140```141"""142143"""144In this example, the keys we are interested in are:145146* `img_path`147* `joints`148149There are a total of 24 entries present inside `joints`. Each entry has 3 values:150151* x-coordinate152* y-coordinate153* visibility flag of the keypoints (1 indicates visibility and 0 indicates non-visibility)154155As we can see `joints` contain multiple `[0, 0, 0]` entries which denote that those156keypoints were not labeled. In this example, we will consider both non-visible as well as157unlabeled keypoints in order to allow mini-batch learning.158"""159160# Load the metdata definition file and preview it.161keypoint_def = pd.read_csv(KEYPOINT_DEF)162keypoint_def.head()163164# Extract the colours and labels.165colours = keypoint_def["Hex colour"].values.tolist()166colours = ["#" + colour for colour in colours]167labels = keypoint_def["Name"].values.tolist()168169170# Utility for reading an image and for getting its annotations.171def get_dog(name):172data = json_dict[name]173img_data = plt.imread(os.path.join(IMG_DIR, data["img_path"]))174# If the image is RGBA convert it to RGB.175if img_data.shape[-1] == 4:176img_data = img_data.astype(np.uint8)177img_data = Image.fromarray(img_data)178img_data = np.array(img_data.convert("RGB"))179data["img_data"] = img_data180181return data182183184"""185## Visualize data186187Now, we write a utility function to visualize the images and their keypoints.188"""189190191# Parts of this code come from here:192# https://github.com/benjiebob/StanfordExtra/blob/master/demo.ipynb193def visualize_keypoints(images, keypoints):194fig, axes = plt.subplots(nrows=len(images), ncols=2, figsize=(16, 12))195[ax.axis("off") for ax in np.ravel(axes)]196197for (ax_orig, ax_all), image, current_keypoint in zip(axes, images, keypoints):198ax_orig.imshow(image)199ax_all.imshow(image)200201# If the keypoints were formed by `imgaug` then the coordinates need202# to be iterated differently.203if isinstance(current_keypoint, KeypointsOnImage):204for idx, kp in enumerate(current_keypoint.keypoints):205ax_all.scatter(206[kp.x],207[kp.y],208c=colours[idx],209marker="x",210s=50,211linewidths=5,212)213else:214current_keypoint = np.array(current_keypoint)215# Since the last entry is the visibility flag, we discard it.216current_keypoint = current_keypoint[:, :2]217for idx, (x, y) in enumerate(current_keypoint):218ax_all.scatter([x], [y], c=colours[idx], marker="x", s=50, linewidths=5)219220plt.tight_layout(pad=2.0)221plt.show()222223224# Select four samples randomly for visualization.225samples = list(json_dict.keys())226num_samples = 4227selected_samples = np.random.choice(samples, num_samples, replace=False)228229images, keypoints = [], []230231for sample in selected_samples:232data = get_dog(sample)233image = data["img_data"]234keypoint = data["joints"]235236images.append(image)237keypoints.append(keypoint)238239visualize_keypoints(images, keypoints)240241"""242The plots show that we have images of non-uniform sizes, which is expected in most243real-world scenarios. However, if we resize these images to have a uniform shape (for244instance (224 x 224)) their ground-truth annotations will also be affected. The same245applies if we apply any geometric transformation (horizontal flip, for e.g.) to an image.246Fortunately, `imgaug` provides utilities that can handle this issue.247In the next section, we will write a data generator inheriting the248[`keras.utils.Sequence`](https://keras.io/api/utils/python_utils/#sequence-class) class249that applies data augmentation on batches of data using `imgaug`.250"""251252"""253## Prepare data generator254"""255256257class KeyPointsDataset(keras.utils.PyDataset):258def __init__(self, image_keys, aug, batch_size=BATCH_SIZE, train=True, **kwargs):259super().__init__(**kwargs)260self.image_keys = image_keys261self.aug = aug262self.batch_size = batch_size263self.train = train264self.on_epoch_end()265266def __len__(self):267return len(self.image_keys) // self.batch_size268269def on_epoch_end(self):270self.indexes = np.arange(len(self.image_keys))271if self.train:272np.random.shuffle(self.indexes)273274def __getitem__(self, index):275indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]276image_keys_temp = [self.image_keys[k] for k in indexes]277(images, keypoints) = self.__data_generation(image_keys_temp)278279return (images, keypoints)280281def __data_generation(self, image_keys_temp):282batch_images = np.empty((self.batch_size, IMG_SIZE, IMG_SIZE, 3), dtype="int")283batch_keypoints = np.empty(284(self.batch_size, 1, 1, NUM_KEYPOINTS), dtype="float32"285)286287for i, key in enumerate(image_keys_temp):288data = get_dog(key)289current_keypoint = np.array(data["joints"])[:, :2]290kps = []291292# To apply our data augmentation pipeline, we first need to293# form Keypoint objects with the original coordinates.294for j in range(0, len(current_keypoint)):295kps.append(Keypoint(x=current_keypoint[j][0], y=current_keypoint[j][1]))296297# We then project the original image and its keypoint coordinates.298current_image = data["img_data"]299kps_obj = KeypointsOnImage(kps, shape=current_image.shape)300301# Apply the augmentation pipeline.302(new_image, new_kps_obj) = self.aug(image=current_image, keypoints=kps_obj)303batch_images[i,] = new_image304305# Parse the coordinates from the new keypoint object.306kp_temp = []307for keypoint in new_kps_obj:308kp_temp.append(np.nan_to_num(keypoint.x))309kp_temp.append(np.nan_to_num(keypoint.y))310311# More on why this reshaping later.312batch_keypoints[i,] = np.array(kp_temp).reshape(1, 1, 24 * 2)313314# Scale the coordinates to [0, 1] range.315batch_keypoints = batch_keypoints / IMG_SIZE316317return (batch_images, batch_keypoints)318319320"""321To know more about how to operate with keypoints in `imgaug` check out322[this document](https://imgaug.readthedocs.io/en/latest/source/examples_keypoints.html).323"""324325"""326## Define augmentation transforms327"""328329train_aug = iaa.Sequential(330[331iaa.Resize(IMG_SIZE, interpolation="linear"),332iaa.Fliplr(0.3),333# `Sometimes()` applies a function randomly to the inputs with334# a given probability (0.3, in this case).335iaa.Sometimes(0.3, iaa.Affine(rotate=10, scale=(0.5, 0.7))),336]337)338339test_aug = iaa.Sequential([iaa.Resize(IMG_SIZE, interpolation="linear")])340341"""342## Create training and validation splits343"""344345np.random.shuffle(samples)346train_keys, validation_keys = (347samples[int(len(samples) * 0.15) :],348samples[: int(len(samples) * 0.15)],349)350351352"""353## Data generator investigation354"""355356train_dataset = KeyPointsDataset(357train_keys, train_aug, workers=2, use_multiprocessing=True358)359validation_dataset = KeyPointsDataset(360validation_keys, test_aug, train=False, workers=2, use_multiprocessing=True361)362363print(f"Total batches in training set: {len(train_dataset)}")364print(f"Total batches in validation set: {len(validation_dataset)}")365366sample_images, sample_keypoints = next(iter(train_dataset))367assert sample_keypoints.max() == 1.0368assert sample_keypoints.min() == 0.0369370sample_keypoints = sample_keypoints[:4].reshape(-1, 24, 2) * IMG_SIZE371visualize_keypoints(sample_images[:4], sample_keypoints)372373"""374## Model building375376The [Stanford dogs dataset](http://vision.stanford.edu/aditya86/ImageNetDogs/) (on which377the StanfordExtra dataset is based) was built using the [ImageNet-1k dataset](http://image-net.org/).378So, it is likely that the models pretrained on the ImageNet-1k dataset would be useful379for this task. We will use a MobileNetV2 pre-trained on this dataset as a backbone to380extract meaningful features from the images and then pass those to a custom regression381head for predicting coordinates.382"""383384385def get_model():386# Load the pre-trained weights of MobileNetV2 and freeze the weights387backbone = keras.applications.MobileNetV2(388weights="imagenet",389include_top=False,390input_shape=(IMG_SIZE, IMG_SIZE, 3),391)392backbone.trainable = False393394inputs = layers.Input((IMG_SIZE, IMG_SIZE, 3))395x = keras.applications.mobilenet_v2.preprocess_input(inputs)396x = backbone(x)397x = layers.Dropout(0.3)(x)398x = layers.SeparableConv2D(399NUM_KEYPOINTS, kernel_size=5, strides=1, activation="relu"400)(x)401outputs = layers.SeparableConv2D(402NUM_KEYPOINTS, kernel_size=3, strides=1, activation="sigmoid"403)(x)404405return keras.Model(inputs, outputs, name="keypoint_detector")406407408"""409Our custom network is fully-convolutional which makes it more parameter-friendly than the410same version of the network having fully-connected dense layers.411"""412413get_model().summary()414415"""416Notice the output shape of the network: `(None, 1, 1, 48)`. This is why we have reshaped417the coordinates as: `batch_keypoints[i, :] = np.array(kp_temp).reshape(1, 1, 24 * 2)`.418"""419420"""421## Model compilation and training422423For this example, we will train the network only for five epochs.424"""425426model = get_model()427model.compile(loss="mse", optimizer=keras.optimizers.Adam(1e-4))428model.fit(train_dataset, validation_data=validation_dataset, epochs=EPOCHS)429430"""431## Make predictions and visualize them432"""433434sample_val_images, sample_val_keypoints = next(iter(validation_dataset))435sample_val_images = sample_val_images[:4]436sample_val_keypoints = sample_val_keypoints[:4].reshape(-1, 24, 2) * IMG_SIZE437predictions = model.predict(sample_val_images).reshape(-1, 24, 2) * IMG_SIZE438439# Ground-truth440visualize_keypoints(sample_val_images, sample_val_keypoints)441442# Predictions443visualize_keypoints(sample_val_images, predictions)444445"""446Predictions will likely improve with more training.447"""448449"""450## Going further451452* Try using other augmentation transforms from `imgaug` to investigate how that changes453the results.454* Here, we transferred the features from the pre-trained network linearly that is we did455not [fine-tune](https://keras.io/guides/transfer_learning/) it. You are encouraged to fine-tune it on this task and see if that456improves the performance. You can also try different architectures and see how they457affect the final performance.458"""459460461