"""
2
Title: OCR model for reading Captchas
3
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
4
Date created: 2020/06/14
5
Last modified: 2024/03/13
6
Description: How to implement an OCR model using CNNs, RNNs and CTC loss.
7
Accelerator: GPU
8
Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
9
"""
10
11
"""
12
## Introduction
13
14
This example demonstrates a simple OCR model built with the Functional API. Apart from
15
combining CNN and RNN, it also illustrates how you can instantiate a new layer
16
and use it as an "Endpoint layer" for implementing CTC loss. For a detailed
17
guide to layer subclassing, please check out
18
[this page](https://keras.io/guides/making_new_layers_and_models_via_subclassing/)
19
in the developer guides.
20
"""
21
22
"""
23
## Setup
24
"""
25
26
import os
27
28
os.environ["KERAS_BACKEND"] = "tensorflow"
29
30
import numpy as np
31
import matplotlib.pyplot as plt
32
33
from pathlib import Path
34
35
import tensorflow as tf
36
import keras
37
from keras import ops
38
from keras import layers
39
40
"""
41
## Load the data: [Captcha Images](https://www.kaggle.com/fournierp/captcha-version-2-images)
42
Let's download the data.
43
"""
44
45
46
"""shell
47
curl -LO https://github.com/AakashKumarNain/CaptchaCracker/raw/master/captcha_images_v2.zip
48
unzip -qq captcha_images_v2.zip
49
"""
50
51
52
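"""
As a side note (a sketch, not part of the original example), the same archive can also be
fetched from Python with `keras.utils.get_file`. The extraction location depends on the
Keras version and the `cache_dir`/`cache_subdir` arguments, so adjust `data_dir` below
accordingly if you go this route.
"""

# keras.utils.get_file(
#     origin="https://github.com/AakashKumarNain/CaptchaCracker/raw/master/captcha_images_v2.zip",
#     extract=True,
# )
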
"""
53
The dataset contains 1040 captcha files as `png` images. The label for each sample is a string,
54
the name of the file (minus the file extension).
55
We will map each character in the string to an integer for training the model. Similary,
56
we will need to map the predictions of the model back to strings. For this purpose
57
we will maintain two dictionaries, mapping characters to integers, and integers to characters,
58
respectively.
59
"""
60
61
62
# Path to the data directory
data_dir = Path("./captcha_images_v2/")

# Get list of all the images
images = sorted(list(map(str, list(data_dir.glob("*.png")))))
labels = [img.split(os.path.sep)[-1].split(".png")[0] for img in images]
characters = set(char for label in labels for char in label)
characters = sorted(list(characters))

print("Number of images found: ", len(images))
print("Number of labels found: ", len(labels))
print("Number of unique characters: ", len(characters))
print("Characters present: ", characters)

# Batch size for training and validation
batch_size = 16

# Desired image dimensions
img_width = 200
img_height = 50

# Factor by which the image is going to be downsampled
# by the convolutional blocks. We will be using two
# convolution blocks and each block will have
# a pooling layer which downsamples the features by a factor of 2.
# Hence the total downsampling factor would be 4.
downsample_factor = 4

# Maximum length of any captcha in the dataset
max_length = max([len(label) for label in labels])

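# A quick sanity check (an addition, not in the original example): after the two
# pooling layers the RNN will see img_width // downsample_factor time steps, and CTC
# needs at least as many time steps as there are characters in the longest label.
assert img_width // downsample_factor >= max_length, (
    "Not enough time steps for CTC: increase img_width or reduce downsampling."
)
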
"""
95
## Preprocessing
96
"""
97
98
99
# Mapping characters to integers
char_to_num = layers.StringLookup(vocabulary=list(characters), mask_token=None)

# Mapping integers back to original characters
num_to_char = layers.StringLookup(
    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)

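# A small round-trip check (an illustrative addition, not part of the original
# example): encode one label to integer ids and map it back to the original string.
sample_label = labels[0]
encoded = char_to_num(tf.strings.unicode_split(sample_label, input_encoding="UTF-8"))
decoded = tf.strings.reduce_join(num_to_char(encoded)).numpy().decode("utf-8")
print(sample_label, "->", encoded.numpy(), "->", decoded)
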
def split_data(images, labels, train_size=0.9, shuffle=True):
    # 1. Get the total size of the dataset
    size = len(images)
    # 2. Make an indices array and shuffle it, if required
    indices = ops.arange(size)
    if shuffle:
        indices = keras.random.shuffle(indices)
    # 3. Get the size of training samples
    train_samples = int(size * train_size)
    # 4. Split data into training and validation sets
    x_train, y_train = images[indices[:train_samples]], labels[indices[:train_samples]]
    x_valid, y_valid = images[indices[train_samples:]], labels[indices[train_samples:]]
    return x_train, x_valid, y_train, y_valid


# Splitting data into training and validation sets
x_train, x_valid, y_train, y_valid = split_data(np.array(images), np.array(labels))


def encode_single_sample(img_path, label):
    # 1. Read image
    img = tf.io.read_file(img_path)
    # 2. Decode and convert to grayscale
    img = tf.io.decode_png(img, channels=1)
    # 3. Convert to float32 in [0, 1] range
    img = tf.image.convert_image_dtype(img, tf.float32)
    # 4. Resize to the desired size
    img = ops.image.resize(img, [img_height, img_width])
    # 5. Transpose the image because we want the time
    # dimension to correspond to the width of the image.
    img = ops.transpose(img, axes=[1, 0, 2])
    # 6. Map the characters in label to numbers
    label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
    # 7. Return a dict as our model is expecting two inputs
    return {"image": img, "label": label}

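# Quick shape check (an illustrative addition, not part of the original example):
# after the transpose, the image tensor is (img_width, img_height, 1), i.e. the width
# comes first so it can act as the time dimension for the RNN.
sample = encode_single_sample(x_train[0], y_train[0])
print(sample["image"].shape, sample["label"].shape)  # e.g. (200, 50, 1) and (5,)
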
"""
146
## Create `Dataset` objects
147
"""
148
149
150
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = (
    train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

validation_dataset = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
validation_dataset = (
    validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

"""
165
## Visualize the data
166
"""
167
168
169
_, ax = plt.subplots(4, 4, figsize=(10, 5))
for batch in train_dataset.take(1):
    images = batch["image"]
    labels = batch["label"]
    for i in range(16):
        img = (images[i] * 255).numpy().astype("uint8")
        label = tf.strings.reduce_join(num_to_char(labels[i])).numpy().decode("utf-8")
        ax[i // 4, i % 4].imshow(img[:, :, 0].T, cmap="gray")
        ax[i // 4, i % 4].set_title(label)
        ax[i // 4, i % 4].axis("off")
plt.show()

"""
182
## Model
183
"""
184
185
186
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
    label_length = ops.cast(ops.squeeze(label_length, axis=-1), dtype="int32")
    input_length = ops.cast(ops.squeeze(input_length, axis=-1), dtype="int32")
    sparse_labels = ops.cast(
        ctc_label_dense_to_sparse(y_true, label_length), dtype="int32"
    )

    y_pred = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.backend.epsilon())

    return ops.expand_dims(
        tf.compat.v1.nn.ctc_loss(
            inputs=y_pred, labels=sparse_labels, sequence_length=input_length
        ),
        1,
    )


def ctc_label_dense_to_sparse(labels, label_lengths):
    label_shape = ops.shape(labels)
    num_batches_tns = ops.stack([label_shape[0]])
    max_num_labels_tns = ops.stack([label_shape[1]])

    def range_less_than(old_input, current_input):
        return ops.expand_dims(ops.arange(ops.shape(old_input)[1]), 0) < tf.fill(
            max_num_labels_tns, current_input
        )

    init = ops.cast(tf.fill([1, label_shape[1]], 0), dtype="bool")
    dense_mask = tf.compat.v1.scan(
        range_less_than, label_lengths, initializer=init, parallel_iterations=1
    )
    dense_mask = dense_mask[:, 0, :]

    label_array = ops.reshape(
        ops.tile(ops.arange(0, label_shape[1]), num_batches_tns), label_shape
    )
    label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask)

    batch_array = ops.transpose(
        ops.reshape(
            ops.tile(ops.arange(0, label_shape[0]), max_num_labels_tns),
            tf.reverse(label_shape, [0]),
        )
    )
    batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask)
    indices = ops.transpose(
        ops.reshape(ops.concatenate([batch_ind, label_ind], axis=0), [2, -1])
    )

    vals_sparse = tf.compat.v1.gather_nd(labels, indices)

    return tf.SparseTensor(
        ops.cast(indices, dtype="int64"),
        vals_sparse,
        ops.cast(label_shape, dtype="int64"),
    )

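# Shape sketch (an illustrative addition, not part of the original example): the CTC
# loss expects per-timestep softmax probabilities over the vocabulary plus one extra
# class for the CTC "blank", dense integer labels, and per-sample length tensors of
# shape (batch, 1). The dummy values below only demonstrate the expected shapes; the
# label values just need to be valid non-blank class ids.
dummy_y_true = tf.random.uniform((2, max_length), minval=1, maxval=5, dtype=tf.int32)
dummy_y_pred = tf.nn.softmax(tf.random.normal((2, 50, len(characters) + 2)), axis=-1)
dummy_input_length = tf.fill((2, 1), 50)
dummy_label_length = tf.fill((2, 1), max_length)
dummy_loss = ctc_batch_cost(
    dummy_y_true, dummy_y_pred, dummy_input_length, dummy_label_length
)
print(dummy_loss.shape)  # (2, 1): one CTC loss value per sample
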
class CTCLayer(layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = ctc_batch_cost

    def call(self, y_true, y_pred):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        batch_len = ops.cast(ops.shape(y_true)[0], dtype="int64")
        input_length = ops.cast(ops.shape(y_pred)[1], dtype="int64")
        label_length = ops.cast(ops.shape(y_true)[1], dtype="int64")

        input_length = input_length * ops.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * ops.ones(shape=(batch_len, 1), dtype="int64")

        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        # At test time, just return the computed predictions
        return y_pred


def build_model():
    # Inputs to the model
    input_img = layers.Input(
        shape=(img_width, img_height, 1), name="image", dtype="float32"
    )
    labels = layers.Input(name="label", shape=(None,), dtype="float32")

    # First conv block
    x = layers.Conv2D(
        32,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv1",
    )(input_img)
    x = layers.MaxPooling2D((2, 2), name="pool1")(x)

    # Second conv block
    x = layers.Conv2D(
        64,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv2",
    )(x)
    x = layers.MaxPooling2D((2, 2), name="pool2")(x)

    # We have used two max pooling layers with pool size and strides of 2.
    # Hence, the downsampled feature maps are 4x smaller. The number of
    # filters in the last layer is 64. Reshape accordingly before
    # passing the output to the RNN part of the model.
    new_shape = ((img_width // 4), (img_height // 4) * 64)
    x = layers.Reshape(target_shape=new_shape, name="reshape")(x)
    x = layers.Dense(64, activation="relu", name="dense1")(x)
    x = layers.Dropout(0.2)(x)

    # RNNs
    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.25))(x)
    x = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.25))(x)

    # Output layer
    x = layers.Dense(
        len(char_to_num.get_vocabulary()) + 1, activation="softmax", name="dense2"
    )(x)

    # Add CTC layer for calculating CTC loss at each step
    output = CTCLayer(name="ctc_loss")(labels, x)

    # Define the model
    model = keras.models.Model(
        inputs=[input_img, labels], outputs=output, name="ocr_model_v1"
    )
    # Optimizer
    opt = keras.optimizers.Adam()
    # Compile the model and return
    model.compile(optimizer=opt)
    return model


# Get the model
model = build_model()
model.summary()

"""
332
## Training
333
"""
334
335
336
epochs = 100
early_stopping_patience = 10
# Add early stopping
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=early_stopping_patience, restore_best_weights=True
)

# Train the model
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    callbacks=[early_stopping],
)

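# Optionally (an addition, not part of the original example) persist the trained
# model so the inference section can be re-run later without retraining; the file
# name here is only a placeholder.
# model.save("ocr_model.keras")
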
"""
354
## Inference
355
356
You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/ocr-for-captcha)
357
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/ocr-for-captcha).
358
"""
359
360
361
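# A sketch of loading the hosted model (an assumption, not part of the original
# example): it requires the `huggingface_hub` package and a Keras version compatible
# with its `from_pretrained_keras` helper.
# from huggingface_hub import from_pretrained_keras
# pretrained_model = from_pretrained_keras("keras-io/ocr-for-captcha")
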
def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
    input_shape = ops.shape(y_pred)
    num_samples, num_steps = input_shape[0], input_shape[1]
    y_pred = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.backend.epsilon())
    input_length = ops.cast(input_length, dtype="int32")

    if greedy:
        (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
            inputs=y_pred, sequence_length=input_length
        )
    else:
        (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(
            inputs=y_pred,
            sequence_length=input_length,
            beam_width=beam_width,
            top_paths=top_paths,
        )
    decoded_dense = []
    for st in decoded:
        st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))
        decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
    return (decoded_dense, log_prob)


# Get the prediction model by extracting layers up to the output layer
prediction_model = keras.models.Model(
    model.input[0], model.get_layer(name="dense2").output
)
prediction_model.summary()


# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search.
    results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    # Iterate over the results and get back the text
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text


# Let's check results on some validation samples
for batch in validation_dataset.take(1):
    batch_images = batch["image"]
    batch_labels = batch["label"]

    preds = prediction_model.predict(batch_images)
    pred_texts = decode_batch_predictions(preds)

    orig_texts = []
    for label in batch_labels:
        label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
        orig_texts.append(label)

    _, ax = plt.subplots(4, 4, figsize=(15, 5))
    for i in range(len(pred_texts)):
        img = (batch_images[i, :, :, 0] * 255).numpy().astype(np.uint8)
        img = img.T
        title = f"Prediction: {pred_texts[i]}"
        ax[i // 4, i % 4].imshow(img, cmap="gray")
        ax[i // 4, i % 4].set_title(title)
        ax[i // 4, i % 4].axis("off")
plt.show()