"""
Title: Handwriting recognition
Authors: [A_K_Nain](https://twitter.com/A_K_Nain), [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/08/16
Last modified: 2024/09/01
Description: Training a handwriting recognition model with variable-length sequences.
Accelerator: GPU
"""

"""
## Introduction

This example shows how the [Captcha OCR](https://keras.io/examples/vision/captcha_ocr/)
example can be extended to the
[IAM Dataset](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database),
which has variable-length ground-truth targets. Each sample in the dataset is an image of some
handwritten text, and its corresponding target is the string present in the image.
The IAM Dataset is widely used across many OCR benchmarks, so we hope this example can serve as a
good starting point for building OCR systems.
"""

"""
## Data collection
"""

"""shell
wget -q https://github.com/sayakpaul/Handwriting-Recognizer-in-Keras/releases/download/v1.0.0/IAM_Words.zip
unzip -qq IAM_Words.zip

mkdir data
mkdir data/words
tar -xf IAM_Words/words.tgz -C data/words
mv IAM_Words/words.txt data
"""

"""
Preview how the dataset is organized. Lines starting with "#" are just metadata.
"""

"""shell
head -20 data/words.txt
"""

"""
## Imports
"""

import keras
from keras.layers import StringLookup
from keras import ops
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import os

np.random.seed(42)
keras.utils.set_random_seed(42)

"""
## Dataset splitting
"""

base_path = "data"
words_list = []

words = open(f"{base_path}/words.txt", "r").readlines()
for line in words:
    if line[0] == "#":
        continue
    if line.split(" ")[1] != "err":  # We don't need to deal with errored entries.
        words_list.append(line)

len(words_list)

np.random.shuffle(words_list)

"""
We will split the dataset into three subsets with a 90:5:5 ratio (train:validation:test).
"""

split_idx = int(0.9 * len(words_list))
train_samples = words_list[:split_idx]
test_samples = words_list[split_idx:]

val_split_idx = int(0.5 * len(test_samples))
validation_samples = test_samples[:val_split_idx]
test_samples = test_samples[val_split_idx:]

assert len(words_list) == len(train_samples) + len(validation_samples) + len(
    test_samples
)

print(f"Total training samples: {len(train_samples)}")
print(f"Total validation samples: {len(validation_samples)}")
print(f"Total test samples: {len(test_samples)}")
"""
98
## Data input pipeline
99
100
We start building our data input pipeline by first preparing the image paths.
101
"""
102
103
base_image_path = os.path.join(base_path, "words")
104
105
106
def get_image_paths_and_labels(samples):
107
paths = []
108
corrected_samples = []
109
for i, file_line in enumerate(samples):
110
line_split = file_line.strip()
111
line_split = line_split.split(" ")
112
113
# Each line split will have this format for the corresponding image:
114
# part1/part1-part2/part1-part2-part3.png
115
image_name = line_split[0]
116
partI = image_name.split("-")[0]
117
partII = image_name.split("-")[1]
118
img_path = os.path.join(
119
base_image_path, partI, partI + "-" + partII, image_name + ".png"
120
)
121
if os.path.getsize(img_path):
122
paths.append(img_path)
123
corrected_samples.append(file_line.split("\n")[0])
124
125
return paths, corrected_samples
126
127
128
train_img_paths, train_labels = get_image_paths_and_labels(train_samples)
129
validation_img_paths, validation_labels = get_image_paths_and_labels(validation_samples)
130
test_img_paths, test_labels = get_image_paths_and_labels(test_samples)
131
132
"""
133
Then we prepare the ground-truth labels.
134
"""
135
136
# Find maximum length and the size of the vocabulary in the training data.
137
train_labels_cleaned = []
138
characters = set()
139
max_len = 0
140
141
for label in train_labels:
142
label = label.split(" ")[-1].strip()
143
for char in label:
144
characters.add(char)
145
146
max_len = max(max_len, len(label))
147
train_labels_cleaned.append(label)
148
149
characters = sorted(list(characters))
150
151
print("Maximum length: ", max_len)
152
print("Vocab size: ", len(characters))
153
154
# Check some label samples.
155
train_labels_cleaned[:10]
156
157
"""
158
Now we clean the validation and the test labels as well.
159
"""
160
161
162
def clean_labels(labels):
163
cleaned_labels = []
164
for label in labels:
165
label = label.split(" ")[-1].strip()
166
cleaned_labels.append(label)
167
return cleaned_labels
168
169
170
validation_labels_cleaned = clean_labels(validation_labels)
171
test_labels_cleaned = clean_labels(test_labels)
172
173
"""
174
### Building the character vocabulary
175
176
Keras provides different preprocessing layers to deal with different modalities of data.
177
[This guide](https://keras.io/api/layers/preprocessing_layers/) provides a comprehensive introduction.
178
Our example involves preprocessing labels at the character
179
level. This means that if there are two labels, e.g. "cat" and "dog", then our character
180
vocabulary should be {a, c, d, g, o, t} (without any special tokens). We use the
181
[`StringLookup`](https://keras.io/api/layers/preprocessing_layers/categorical/string_lookup/)
182
layer for this purpose.
183
"""
184
185
186
AUTOTUNE = tf.data.AUTOTUNE
187
188
# Mapping characters to integers.
189
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
190
191
# Mapping integers back to original characters.
192
num_to_char = StringLookup(
193
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
194
)
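
"""
As a quick, optional sanity check, we can round-trip an arbitrary word through the two
lookup layers. The word "hello" below is only illustrative; any character that does not
occur in the training vocabulary would be mapped to the OOV token.
"""

example_label = "hello"  # An arbitrary word, used only to illustrate the mapping.
example_ids = char_to_num(
    tf.strings.unicode_split(example_label, input_encoding="UTF-8")
)
example_chars = tf.strings.reduce_join(num_to_char(example_ids)).numpy().decode("utf-8")
print(example_ids.numpy(), example_chars)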

"""
### Resizing images without distortion

Instead of square images, many OCR models work with rectangular images. This will become
clearer in a moment when we visualize a few samples from the dataset. While
aspect-unaware resizing of square images does not introduce a significant amount of
distortion, this is not the case for rectangular images. But resizing images to a uniform
size is a requirement for mini-batching. So we need to perform our resizing such that
the following criteria are met:

* Aspect ratio is preserved.
* Content of the images is not affected.
"""


def distortion_free_resize(image, img_size):
    w, h = img_size
    image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)

    # Check the amount of padding needed.
    pad_height = h - ops.shape(image)[0]
    pad_width = w - ops.shape(image)[1]

    # Only necessary if you want to do the same amount of padding on both sides.
    if pad_height % 2 != 0:
        height = pad_height // 2
        pad_height_top = height + 1
        pad_height_bottom = height
    else:
        pad_height_top = pad_height_bottom = pad_height // 2

    if pad_width % 2 != 0:
        width = pad_width // 2
        pad_width_left = width + 1
        pad_width_right = width
    else:
        pad_width_left = pad_width_right = pad_width // 2

    image = tf.pad(
        image,
        paddings=[
            [pad_height_top, pad_height_bottom],
            [pad_width_left, pad_width_right],
            [0, 0],
        ],
    )

    # Transpose and flip so that the image width becomes the "time" dimension
    # fed to the recurrent layers later in the model.
    image = ops.transpose(image, (1, 0, 2))
    image = tf.image.flip_left_right(image)
    return image
"""
249
If we just go with the plain resizing then the images would look like so:
250
251
![](https://i.imgur.com/eqq3s4N.png)
252
253
Notice how this resizing would have introduced unnecessary stretching.
254
"""
255
256
"""
257
### Putting the utilities together
258
"""
259
260
batch_size = 64
padding_token = 99
image_width = 128
image_height = 32


def preprocess_image(image_path, img_size=(image_width, image_height)):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, 1)
    image = distortion_free_resize(image, img_size)
    image = ops.cast(image, tf.float32) / 255.0
    return image


def vectorize_label(label):
    label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
    length = ops.shape(label)[0]
    pad_amount = max_len - length
    label = tf.pad(label, paddings=[[0, pad_amount]], constant_values=padding_token)
    return label


def process_images_labels(image_path, label):
    image = preprocess_image(image_path)
    label = vectorize_label(label)
    return {"image": image, "label": label}


def prepare_dataset(image_paths, labels):
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels)).map(
        process_images_labels, num_parallel_calls=AUTOTUNE
    )
    return dataset.batch(batch_size).cache().prefetch(AUTOTUNE)


"""
## Prepare `tf.data.Dataset` objects
"""

train_ds = prepare_dataset(train_img_paths, train_labels_cleaned)
validation_ds = prepare_dataset(validation_img_paths, validation_labels_cleaned)
test_ds = prepare_dataset(test_img_paths, test_labels_cleaned)
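
"""
Each element of these datasets is a dictionary with an `"image"` batch and a `"label"`
batch. As a quick, optional check, we can print the shapes of one batch; the images
should come out as `(batch_size, image_width, image_height, 1)` and the labels as
`(batch_size, max_len)`.
"""

for check_batch in train_ds.take(1):
    print(check_batch["image"].shape, check_batch["label"].shape)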

"""
## Visualize a few samples
"""

for data in train_ds.take(1):
    images, labels = data["image"], data["label"]

    _, ax = plt.subplots(4, 4, figsize=(15, 8))

    for i in range(16):
        img = images[i]
        img = tf.image.flip_left_right(img)
        img = ops.transpose(img, (1, 0, 2))
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
        img = img[:, :, 0]

        # Gather indices where label != padding_token.
        label = labels[i]
        indices = tf.gather(label, tf.where(tf.math.not_equal(label, padding_token)))
        # Convert to string.
        label = tf.strings.reduce_join(num_to_char(indices))
        label = label.numpy().decode("utf-8")

        ax[i // 4, i % 4].imshow(img, cmap="gray")
        ax[i // 4, i % 4].set_title(label)
        ax[i // 4, i % 4].axis("off")


plt.show()

"""
You will notice that the content of the original image has been kept as faithful as
possible and that the image has been padded accordingly.
"""
"""
339
## Model
340
341
Our model will use the CTC loss as an endpoint layer. For a detailed understanding of the
342
CTC loss, refer to [this post](https://distill.pub/2017/ctc/).
343
"""


class CTCLayer(keras.layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = tf.keras.backend.ctc_batch_cost

    def call(self, y_true, y_pred):
        batch_len = ops.cast(ops.shape(y_true)[0], dtype="int64")
        input_length = ops.cast(ops.shape(y_pred)[1], dtype="int64")
        label_length = ops.cast(ops.shape(y_true)[1], dtype="int64")

        input_length = input_length * ops.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * ops.ones(shape=(batch_len, 1), dtype="int64")
        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        # At test time, just return the computed predictions.
        return y_pred


def build_model():
    # Inputs to the model.
    input_img = keras.Input(shape=(image_width, image_height, 1), name="image")
    labels = keras.layers.Input(name="label", shape=(None,))

    # First conv block.
    x = keras.layers.Conv2D(
        32,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv1",
    )(input_img)
    x = keras.layers.MaxPooling2D((2, 2), name="pool1")(x)

    # Second conv block.
    x = keras.layers.Conv2D(
        64,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv2",
    )(x)
    x = keras.layers.MaxPooling2D((2, 2), name="pool2")(x)

    # We have used two max pooling layers with pool size and strides of 2.
    # Hence, the downsampled feature maps are 4x smaller. The number of
    # filters in the last layer is 64. Reshape accordingly before
    # passing the output to the RNN part of the model.
    new_shape = ((image_width // 4), (image_height // 4) * 64)
    x = keras.layers.Reshape(target_shape=new_shape, name="reshape")(x)
    x = keras.layers.Dense(64, activation="relu", name="dense1")(x)
    x = keras.layers.Dropout(0.2)(x)

    # RNNs.
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(128, return_sequences=True, dropout=0.25)
    )(x)
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(64, return_sequences=True, dropout=0.25)
    )(x)

    # +2 is to account for the two special tokens introduced by the CTC loss.
    # The recommendation comes from here: https://git.io/J0eXP.
    x = keras.layers.Dense(
        len(char_to_num.get_vocabulary()) + 2, activation="softmax", name="dense2"
    )(x)

    # Add CTC layer for calculating CTC loss at each step.
    output = CTCLayer(name="ctc_loss")(labels, x)

    # Define the model.
    model = keras.models.Model(
        inputs=[input_img, labels], outputs=output, name="handwriting_recognizer"
    )
    # Optimizer.
    opt = keras.optimizers.Adam()
    # Compile the model and return.
    model.compile(optimizer=opt)
    return model


# Get the model.
model = build_model()
model.summary()
"""
434
## Evaluation metric
435
436
[Edit Distance](https://en.wikipedia.org/wiki/Edit_distance)
437
is the most widely used metric for evaluating OCR models. In this section, we will
438
implement it and use it as a callback to monitor our model.
439
"""

"""
We first segregate the validation images and their labels for convenience.
"""

validation_images = []
validation_labels = []

for batch in validation_ds:
    validation_images.append(batch["image"])
    validation_labels.append(batch["label"])

"""
Now, we create a callback to monitor the edit distances.
"""


def calculate_edit_distance(labels, predictions):
    # Get a single batch and convert its labels to sparse tensors.
    sparse_labels = ops.cast(tf.sparse.from_dense(labels), dtype=tf.int64)

    # Make predictions and convert them to sparse tensors.
    input_len = np.ones(predictions.shape[0]) * predictions.shape[1]
    predictions_decoded = keras.ops.nn.ctc_decode(
        predictions, sequence_lengths=input_len
    )[0][0][:, :max_len]
    sparse_predictions = ops.cast(
        tf.sparse.from_dense(predictions_decoded), dtype=tf.int64
    )

    # Compute individual edit distances and average them out.
    edit_distances = tf.edit_distance(
        sparse_predictions, sparse_labels, normalize=False
    )
    return tf.reduce_mean(edit_distances)


class EditDistanceCallback(keras.callbacks.Callback):
    def __init__(self, pred_model):
        super().__init__()
        self.prediction_model = pred_model

    def on_epoch_end(self, epoch, logs=None):
        edit_distances = []

        for i in range(len(validation_images)):
            labels = validation_labels[i]
            predictions = self.prediction_model.predict(validation_images[i])
            edit_distances.append(calculate_edit_distance(labels, predictions).numpy())

        print(
            f"Mean edit distance for epoch {epoch + 1}: {np.mean(edit_distances):.4f}"
        )
"""
495
## Training
496
497
Now we are ready to kick off model training.
498
"""
499
500
epochs = 10 # To get good results this should be at least 50.
501
502
model = build_model()
503
prediction_model = keras.models.Model(
504
model.get_layer(name="image").output, model.get_layer(name="dense2").output
505
)
506
edit_distance_callback = EditDistanceCallback(prediction_model)
507
508
# Train the model.
509
history = model.fit(
510
train_ds,
511
validation_data=validation_ds,
512
epochs=epochs,
513
callbacks=[edit_distance_callback],
514
)
515
516
517
"""
518
## Inference
519
"""
520
521
522
# A utility function to decode the output of the network.
523
def decode_batch_predictions(pred):
524
input_len = np.ones(pred.shape[0]) * pred.shape[1]
525
# Use greedy search. For complex tasks, you can use beam search.
526
results = keras.ops.nn.ctc_decode(pred, sequence_lengths=input_len)[0][0][
527
:, :max_len
528
]
529
# Iterate over the results and get back the text.
530
output_text = []
531
for res in results:
532
res = tf.gather(res, tf.where(tf.math.not_equal(res, -1)))
533
res = (
534
tf.strings.reduce_join(num_to_char(res))
535
.numpy()
536
.decode("utf-8")
537
.replace("[UNK]", "")
538
)
539
output_text.append(res)
540
return output_text
541
542
543
# Let's check results on some test samples.
544
for batch in test_ds.take(1):
545
batch_images = batch["image"]
546
_, ax = plt.subplots(4, 4, figsize=(15, 8))
547
548
preds = prediction_model.predict(batch_images)
549
pred_texts = decode_batch_predictions(preds)
550
551
for i in range(16):
552
img = batch_images[i]
553
img = tf.image.flip_left_right(img)
554
img = ops.transpose(img, (1, 0, 2))
555
img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
556
img = img[:, :, 0]
557
558
title = f"Prediction: {pred_texts[i]}"
559
ax[i // 4, i % 4].imshow(img, cmap="gray")
560
ax[i // 4, i % 4].set_title(title)
561
ax[i // 4, i % 4].axis("off")
562
563
plt.show()
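
"""
The same utilities can be reused to transcribe a single image outside of the `tf.data`
pipeline. Below is a minimal sketch that runs the prediction model on the first test
image path; any other word-image path from the dataset would work the same way.
"""

sample_path = test_img_paths[0]
sample_image = preprocess_image(sample_path)
# The model expects a batch dimension, so add one before predicting.
sample_pred = prediction_model.predict(ops.expand_dims(sample_image, axis=0))
print("Predicted text:", decode_batch_predictions(sample_pred)[0])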

"""
To get better results, the model should be trained for at least 50 epochs.
"""

"""
## Final remarks

* The `prediction_model` is fully compatible with TensorFlow Lite. If you are interested,
you can use it inside a mobile application. You may find
[this notebook](https://github.com/tulasiram58827/ocr_tflite/blob/main/colabs/captcha_ocr_tflite.ipynb)
to be useful in this regard. A minimal conversion sketch follows this list.
* Not all the training examples are perfectly aligned, as can be observed in this example. This
can hurt model performance for complex sequences. To this end, we can leverage
Spatial Transformer Networks ([Jaderberg et al.](https://arxiv.org/abs/1506.02025)),
which can help the model learn affine transformations that maximize its performance.
"""