Path: blob/master/examples/vision/handwriting_recognition.py
"""1Title: Handwriting recognition2Authors: [A_K_Nain](https://twitter.com/A_K_Nain), [Sayak Paul](https://twitter.com/RisingSayak)3Date created: 2021/08/164Last modified: 2024/09/015Description: Training a handwriting recognition model with variable-length sequences.6Accelerator: GPU7"""89"""10## Introduction1112This example shows how the [Captcha OCR](https://keras.io/examples/vision/captcha_ocr/)13example can be extended to the14[IAM Dataset](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database),15which has variable length ground-truth targets. Each sample in the dataset is an image of some16handwritten text, and its corresponding target is the string present in the image.17The IAM Dataset is widely used across many OCR benchmarks, so we hope this example can serve as a18good starting point for building OCR systems.19"""2021"""22## Data collection23"""2425"""shell26wget -q https://github.com/sayakpaul/Handwriting-Recognizer-in-Keras/releases/download/v1.0.0/IAM_Words.zip27unzip -qq IAM_Words.zip2829mkdir data30mkdir data/words31tar -xf IAM_Words/words.tgz -C data/words32mv IAM_Words/words.txt data33"""3435"""36Preview how the dataset is organized. Lines prepended by "#" are just metadata information.37"""3839"""shell40head -20 data/words.txt41"""4243"""44## Imports45"""4647import keras48from keras.layers import StringLookup49from keras import ops50import matplotlib.pyplot as plt51import tensorflow as tf52import numpy as np53import os5455np.random.seed(42)56keras.utils.set_random_seed(42)5758"""59## Dataset splitting60"""6162base_path = "data"63words_list = []6465words = open(f"{base_path}/words.txt", "r").readlines()66for line in words:67if line[0] == "#":68continue69if line.split(" ")[1] != "err": # We don't need to deal with errored entries.70words_list.append(line)7172len(words_list)7374np.random.shuffle(words_list)7576"""77We will split the dataset into three subsets with a 90:5:5 ratio (train:validation:test).78"""7980split_idx = int(0.9 * len(words_list))81train_samples = words_list[:split_idx]82test_samples = words_list[split_idx:]8384val_split_idx = int(0.5 * len(test_samples))85validation_samples = test_samples[:val_split_idx]86test_samples = test_samples[val_split_idx:]8788assert len(words_list) == len(train_samples) + len(validation_samples) + len(89test_samples90)9192print(f"Total training samples: {len(train_samples)}")93print(f"Total validation samples: {len(validation_samples)}")94print(f"Total test samples: {len(test_samples)}")9596"""97## Data input pipeline9899We start building our data input pipeline by first preparing the image paths.100"""101102base_image_path = os.path.join(base_path, "words")103104105def get_image_paths_and_labels(samples):106paths = []107corrected_samples = []108for i, file_line in enumerate(samples):109line_split = file_line.strip()110line_split = line_split.split(" ")111112# Each line split will have this format for the corresponding image:113# part1/part1-part2/part1-part2-part3.png114image_name = line_split[0]115partI = image_name.split("-")[0]116partII = image_name.split("-")[1]117img_path = os.path.join(118base_image_path, partI, partI + "-" + partII, image_name + ".png"119)120if os.path.getsize(img_path):121paths.append(img_path)122corrected_samples.append(file_line.split("\n")[0])123124return paths, corrected_samples125126127train_img_paths, train_labels = get_image_paths_and_labels(train_samples)128validation_img_paths, validation_labels = get_image_paths_and_labels(validation_samples)129test_img_paths, test_labels = 
"""
Then we prepare the ground-truth labels.
"""

# Find maximum length and the size of the vocabulary in the training data.
train_labels_cleaned = []
characters = set()
max_len = 0

for label in train_labels:
    label = label.split(" ")[-1].strip()
    for char in label:
        characters.add(char)

    max_len = max(max_len, len(label))
    train_labels_cleaned.append(label)

characters = sorted(list(characters))

print("Maximum length: ", max_len)
print("Vocab size: ", len(characters))

# Check some label samples.
train_labels_cleaned[:10]

"""
Now we clean the validation and the test labels as well.
"""


def clean_labels(labels):
    cleaned_labels = []
    for label in labels:
        label = label.split(" ")[-1].strip()
        cleaned_labels.append(label)
    return cleaned_labels


validation_labels_cleaned = clean_labels(validation_labels)
test_labels_cleaned = clean_labels(test_labels)

"""
### Building the character vocabulary

Keras provides different preprocessing layers to deal with different modalities of data.
[This guide](https://keras.io/api/layers/preprocessing_layers/) provides a comprehensive introduction.
Our example involves preprocessing labels at the character
level. This means that if there are two labels, e.g. "cat" and "dog", then our character
vocabulary should be {a, c, d, g, o, t} (without any special tokens). We use the
[`StringLookup`](https://keras.io/api/layers/preprocessing_layers/categorical/string_lookup/)
layer for this purpose.
"""


AUTOTUNE = tf.data.AUTOTUNE

# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)

# Mapping integers back to original characters.
num_to_char = StringLookup(
    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
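"""
To see what these two layers do, we can round-trip a short string through them. The exact
integer IDs depend on the vocabulary built from the training labels, so treat the printed
values as illustrative rather than fixed:
"""

# Round-trip a sample string: characters -> integer IDs -> characters.
sample_chars = tf.strings.unicode_split("cat", input_encoding="UTF-8")
sample_ids = char_to_num(sample_chars)
recovered = tf.strings.reduce_join(num_to_char(sample_ids)).numpy().decode("utf-8")
print(sample_ids.numpy(), "->", recovered)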
"""
### Resizing images without distortion

Instead of square images, many OCR models work with rectangular images. This will become
clearer in a moment when we visualize a few samples from the dataset. While aspect-unaware
resizing does not introduce a significant amount of distortion for square images, that is
not the case for rectangular images. Since resizing images to a uniform size is a
requirement for mini-batching, we need to perform our resizing such that the following
criteria are met:

* Aspect ratio is preserved.
* Content of the images is not affected.
"""


def distortion_free_resize(image, img_size):
    w, h = img_size
    image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)

    # Check the amount of padding needed.
    pad_height = h - ops.shape(image)[0]
    pad_width = w - ops.shape(image)[1]

    # Only necessary if you want to do the same amount of padding on both sides.
    if pad_height % 2 != 0:
        height = pad_height // 2
        pad_height_top = height + 1
        pad_height_bottom = height
    else:
        pad_height_top = pad_height_bottom = pad_height // 2

    if pad_width % 2 != 0:
        width = pad_width // 2
        pad_width_left = width + 1
        pad_width_right = width
    else:
        pad_width_left = pad_width_right = pad_width // 2

    image = tf.pad(
        image,
        paddings=[
            [pad_height_top, pad_height_bottom],
            [pad_width_left, pad_width_right],
            [0, 0],
        ],
    )

    image = ops.transpose(image, (1, 0, 2))
    image = tf.image.flip_left_right(image)
    return image


"""
If we just go with the plain resizing then the images would look like so:

Notice how this resizing would have introduced unnecessary stretching.
"""

"""
### Putting the utilities together
"""

batch_size = 64
padding_token = 99
image_width = 128
image_height = 32


def preprocess_image(image_path, img_size=(image_width, image_height)):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, 1)
    image = distortion_free_resize(image, img_size)
    image = ops.cast(image, tf.float32) / 255.0
    return image


def vectorize_label(label):
    label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
    length = ops.shape(label)[0]
    pad_amount = max_len - length
    label = tf.pad(label, paddings=[[0, pad_amount]], constant_values=padding_token)
    return label


def process_images_labels(image_path, label):
    image = preprocess_image(image_path)
    label = vectorize_label(label)
    return {"image": image, "label": label}


def prepare_dataset(image_paths, labels):
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels)).map(
        process_images_labels, num_parallel_calls=AUTOTUNE
    )
    return dataset.batch(batch_size).cache().prefetch(AUTOTUNE)


"""
## Prepare `tf.data.Dataset` objects
"""

train_ds = prepare_dataset(train_img_paths, train_labels_cleaned)
validation_ds = prepare_dataset(validation_img_paths, validation_labels_cleaned)
test_ds = prepare_dataset(test_img_paths, test_labels_cleaned)

"""
## Visualize a few samples
"""

for data in train_ds.take(1):
    images, labels = data["image"], data["label"]

    _, ax = plt.subplots(4, 4, figsize=(15, 8))

    for i in range(16):
        img = images[i]
        img = tf.image.flip_left_right(img)
        img = ops.transpose(img, (1, 0, 2))
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
        img = img[:, :, 0]

        # Gather indices where label != padding_token.
        label = labels[i]
        indices = tf.gather(label, tf.where(tf.math.not_equal(label, padding_token)))
        # Convert to string.
        label = tf.strings.reduce_join(num_to_char(indices))
        label = label.numpy().decode("utf-8")

        ax[i // 4, i % 4].imshow(img, cmap="gray")
        ax[i // 4, i % 4].set_title(label)
        ax[i // 4, i % 4].axis("off")


plt.show()

"""
You will notice that the content of the original image has been kept as faithful as
possible and padded accordingly.
"""
"""
## Model

Our model will use the CTC loss as an endpoint layer. For a detailed understanding of the
CTC loss, refer to [this post](https://distill.pub/2017/ctc/).
"""


class CTCLayer(keras.layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = tf.keras.backend.ctc_batch_cost

    def call(self, y_true, y_pred):
        batch_len = ops.cast(ops.shape(y_true)[0], dtype="int64")
        input_length = ops.cast(ops.shape(y_pred)[1], dtype="int64")
        label_length = ops.cast(ops.shape(y_true)[1], dtype="int64")

        input_length = input_length * ops.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * ops.ones(shape=(batch_len, 1), dtype="int64")
        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        # At test time, just return the computed predictions.
        return y_pred


def build_model():
    # Inputs to the model.
    input_img = keras.Input(shape=(image_width, image_height, 1), name="image")
    labels = keras.layers.Input(name="label", shape=(None,))

    # First conv block.
    x = keras.layers.Conv2D(
        32,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv1",
    )(input_img)
    x = keras.layers.MaxPooling2D((2, 2), name="pool1")(x)

    # Second conv block.
    x = keras.layers.Conv2D(
        64,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv2",
    )(x)
    x = keras.layers.MaxPooling2D((2, 2), name="pool2")(x)

    # We have used two max pooling layers with pool size and strides of 2.
    # Hence, downsampled feature maps are 4x smaller. The number of
    # filters in the last layer is 64. Reshape accordingly before
    # passing the output to the RNN part of the model.
    new_shape = ((image_width // 4), (image_height // 4) * 64)
    x = keras.layers.Reshape(target_shape=new_shape, name="reshape")(x)
    x = keras.layers.Dense(64, activation="relu", name="dense1")(x)
    x = keras.layers.Dropout(0.2)(x)

    # RNNs.
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(128, return_sequences=True, dropout=0.25)
    )(x)
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(64, return_sequences=True, dropout=0.25)
    )(x)

    # +2 is to account for the two special tokens introduced by the CTC loss.
    # The recommendation comes from here: https://git.io/J0eXP.
    x = keras.layers.Dense(
        len(char_to_num.get_vocabulary()) + 2, activation="softmax", name="dense2"
    )(x)

    # Add CTC layer for calculating CTC loss at each step.
    output = CTCLayer(name="ctc_loss")(labels, x)

    # Define the model.
    model = keras.models.Model(
        inputs=[input_img, labels], outputs=output, name="handwriting_recognizer"
    )
    # Optimizer.
    opt = keras.optimizers.Adam()
    # Compile the model and return.
    model.compile(optimizer=opt)
    return model


# Get the model.
model = build_model()
model.summary()

"""
## Evaluation metric

[Edit Distance](https://en.wikipedia.org/wiki/Edit_distance)
is the most widely used metric for evaluating OCR models. In this section, we will
implement it and use it as a callback to monitor our model.
"""
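"""
As a quick illustration of the metric itself: the edit distance between two strings is
the number of insertions, deletions, and substitutions needed to turn one into the other.
`tf.edit_distance` works on sparse tensors, so we first map the characters to integer IDs
with our `char_to_num` layer. This is a minimal sketch with made-up strings; the IDs
depend on the learned vocabulary, but the distance does not:
"""

# "kitten" -> "sitten" -> "sittin" -> "sitting": three edits.
hyp_ids = char_to_num(tf.strings.unicode_split("kitten", input_encoding="UTF-8"))
truth_ids = char_to_num(tf.strings.unicode_split("sitting", input_encoding="UTF-8"))
hypothesis = tf.sparse.from_dense(hyp_ids[tf.newaxis, :])
truth = tf.sparse.from_dense(truth_ids[tf.newaxis, :])
print(tf.edit_distance(hypothesis, truth, normalize=False).numpy())  # Expected: [3.]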
"""
We first segregate the validation images and their labels for convenience.
"""

validation_images = []
validation_labels = []

for batch in validation_ds:
    validation_images.append(batch["image"])
    validation_labels.append(batch["label"])

"""
Now, we create a callback to monitor the edit distances.
"""


def calculate_edit_distance(labels, predictions):
    # Get a single batch and convert its labels to sparse tensors.
    sparse_labels = ops.cast(tf.sparse.from_dense(labels), dtype=tf.int64)

    # Make predictions and convert them to sparse tensors.
    input_len = np.ones(predictions.shape[0]) * predictions.shape[1]
    predictions_decoded = keras.ops.nn.ctc_decode(
        predictions, sequence_lengths=input_len
    )[0][0][:, :max_len]
    sparse_predictions = ops.cast(
        tf.sparse.from_dense(predictions_decoded), dtype=tf.int64
    )

    # Compute individual edit distances and average them out.
    edit_distances = tf.edit_distance(
        sparse_predictions, sparse_labels, normalize=False
    )
    return tf.reduce_mean(edit_distances)


class EditDistanceCallback(keras.callbacks.Callback):
    def __init__(self, pred_model):
        super().__init__()
        self.prediction_model = pred_model

    def on_epoch_end(self, epoch, logs=None):
        edit_distances = []

        for i in range(len(validation_images)):
            labels = validation_labels[i]
            predictions = self.prediction_model.predict(validation_images[i])
            edit_distances.append(calculate_edit_distance(labels, predictions).numpy())

        print(
            f"Mean edit distance for epoch {epoch + 1}: {np.mean(edit_distances):.4f}"
        )


"""
## Training

Now we are ready to kick off model training.
"""

epochs = 10  # To get good results this should be at least 50.

model = build_model()
prediction_model = keras.models.Model(
    model.get_layer(name="image").output, model.get_layer(name="dense2").output
)
edit_distance_callback = EditDistanceCallback(prediction_model)

# Train the model.
history = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=epochs,
    callbacks=[edit_distance_callback],
)
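"""
Optionally, we can plot the training and validation loss curves recorded in the `history`
object returned by `fit()`. This is a minimal sketch; the exact curves depend on the run
and the number of epochs:
"""

# Plot the loss curves recorded by `fit()`.
plt.figure(figsize=(8, 4))
plt.plot(history.history["loss"], label="train loss")
plt.plot(history.history["val_loss"], label="val loss")
plt.xlabel("Epoch")
plt.ylabel("CTC loss")
plt.legend()
plt.show()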
"""
## Inference
"""


# A utility function to decode the output of the network.
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search.
    results = keras.ops.nn.ctc_decode(pred, sequence_lengths=input_len)[0][0][
        :, :max_len
    ]
    # Iterate over the results and get back the text.
    output_text = []
    for res in results:
        res = tf.gather(res, tf.where(tf.math.not_equal(res, -1)))
        res = (
            tf.strings.reduce_join(num_to_char(res))
            .numpy()
            .decode("utf-8")
            .replace("[UNK]", "")
        )
        output_text.append(res)
    return output_text


# Let's check results on some test samples.
for batch in test_ds.take(1):
    batch_images = batch["image"]
    _, ax = plt.subplots(4, 4, figsize=(15, 8))

    preds = prediction_model.predict(batch_images)
    pred_texts = decode_batch_predictions(preds)

    for i in range(16):
        img = batch_images[i]
        img = tf.image.flip_left_right(img)
        img = ops.transpose(img, (1, 0, 2))
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
        img = img[:, :, 0]

        title = f"Prediction: {pred_texts[i]}"
        ax[i // 4, i % 4].imshow(img, cmap="gray")
        ax[i // 4, i % 4].set_title(title)
        ax[i // 4, i % 4].axis("off")

plt.show()

"""
To get better results, the model should be trained for at least 50 epochs.
"""

"""
## Final remarks

* The `prediction_model` is fully compatible with TensorFlow Lite. If you are interested,
you can use it inside a mobile application. You may find
[this notebook](https://github.com/tulasiram58827/ocr_tflite/blob/main/colabs/captcha_ocr_tflite.ipynb)
to be useful in this regard. A minimal conversion sketch is shown below.
* Not all the training examples are perfectly aligned, as observed in this example. This
can hurt model performance for complex sequences. To this end, we can leverage
Spatial Transformer Networks ([Jaderberg et al.](https://arxiv.org/abs/1506.02025))
that can help the model learn affine transformations that maximize its performance.
"""
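"""
The sketch below shows one way to convert `prediction_model` with the standard
`tf.lite.TFLiteConverter` workflow. Treat it as an assumption-laden starting point: the
exact flags and op sets may vary with your TensorFlow/Keras version and deployment target.
"""

# Convert the inference-only model to a TFLite flatbuffer (sketch).
converter = tf.lite.TFLiteConverter.from_keras_model(prediction_model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Prefer built-in TFLite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # Fall back to TF ops (e.g. for the LSTMs) if needed.
]
tflite_model = converter.convert()
with open("handwriting_recognizer.tflite", "wb") as f:
    f.write(tflite_model)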