"""
Title: Automatic Speech Recognition using CTC
Authors: [Mohamed Reda Bouadjenek](https://rbouadjenek.github.io/) and [Ngoc Dung Huynh](https://www.linkedin.com/in/parkerhuynh/)
Date created: 2021/09/26
Last modified: 2021/09/26
Description: Training a CTC-based model for automatic speech recognition.
Accelerator: GPU
"""

"""
## Introduction

Speech recognition is an interdisciplinary subfield of computer science
and computational linguistics that develops methodologies and technologies
that enable the recognition and translation of spoken language into text
by computers. It is also known as automatic speech recognition (ASR),
computer speech recognition or speech to text (STT). It incorporates
knowledge and research in the computer science, linguistics and computer
engineering fields.

This demonstration shows how to combine a 2D CNN, RNN and a Connectionist
Temporal Classification (CTC) loss to build an ASR model. CTC is an algorithm
used to train deep neural networks in speech recognition, handwriting
recognition and other sequence problems. CTC is used when we don't know
how the input aligns with the output (how the characters in the transcript
align to the audio). The model we create is similar to
[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html).

We will use the LJSpeech dataset from the
[LibriVox](https://librivox.org/) project. It consists of short
audio clips of a single speaker reading passages from 7 non-fiction books.

We will evaluate the quality of the model using
[Word Error Rate (WER)](https://en.wikipedia.org/wiki/Word_error_rate).
WER is obtained by adding up the substitutions, insertions, and deletions
that occur in a sequence of recognized words, then dividing that total by
the number of words originally spoken. To compute the WER you need to install the
[jiwer](https://pypi.org/project/jiwer/) package. You can use the following command line:

```
pip install jiwer
```
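
As a quick sanity check of how `jiwer` computes the score (a minimal sketch, not part of
the original example; the two sentences below are made up):

```
from jiwer import wer

ground_truth = "the quick brown fox jumps"
hypothesis = "the quick brown box jumps"
# One substitution ("box" for "fox") out of five reference words -> WER = 0.2
print(wer(ground_truth, hypothesis))
```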

**References:**

- [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/)
- [Speech recognition](https://en.wikipedia.org/wiki/Speech_recognition)
- [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
- [DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html)

"""

"""
## Setup
"""

import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from IPython import display
from jiwer import wer


"""
## Load the LJSpeech Dataset

Let's download the [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/).
The dataset contains 13,100 audio files as `wav` files in the `/wavs/` folder.
The label (transcript) for each audio file is a string
given in the `metadata.csv` file. The fields are:

- **ID**: this is the name of the corresponding .wav file
- **Transcription**: words spoken by the reader (UTF-8)
- **Normalized transcription**: transcription with numbers,
ordinals, and monetary units expanded into full words (UTF-8).

For this demo we will use only the "Normalized transcription" field.

Each audio file is a single-channel 16-bit PCM WAV with a sample rate of 22,050 Hz.
"""

data_url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
data_path = keras.utils.get_file("LJSpeech-1.1", data_url, untar=True)
wavs_path = data_path + "/wavs/"
metadata_path = data_path + "/metadata.csv"


# Read metadata file and parse it
metadata_df = pd.read_csv(metadata_path, sep="|", header=None, quoting=3)
metadata_df.columns = ["file_name", "transcription", "normalized_transcription"]
metadata_df = metadata_df[["file_name", "normalized_transcription"]]
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
metadata_df.head(3)


"""
We now split the data into training and validation sets.
"""

split = int(len(metadata_df) * 0.90)
df_train = metadata_df[:split]
df_val = metadata_df[split:]

print(f"Size of the training set: {len(df_train)}")
print(f"Size of the validation set: {len(df_val)}")


"""
## Preprocessing

We first prepare the vocabulary to be used.
"""

# The set of characters accepted in the transcription.
characters = [x for x in "abcdefghijklmnopqrstuvwxyz'?! "]
# Mapping characters to integers
char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token="")
# Mapping integers back to original characters
num_to_char = keras.layers.StringLookup(
    vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True
)

print(
    f"The vocabulary is: {char_to_num.get_vocabulary()} "
    f"(size ={char_to_num.vocabulary_size()})"
)
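
"""
As a quick round-trip check (a small sketch, not part of the original example), the two
lookup layers are inverses of each other on in-vocabulary characters:
"""

example = char_to_num(tf.strings.unicode_split("hello world", input_encoding="UTF-8"))
print("Encoded:", example.numpy())
print("Decoded:", tf.strings.reduce_join(num_to_char(example)).numpy().decode("utf-8"))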

"""
Next, we create the function that describes the transformation that we apply to each
element of our dataset.
"""

# An integer scalar Tensor. The window length in samples.
frame_length = 256
# An integer scalar Tensor. The number of samples to step.
frame_step = 160
# An integer scalar Tensor. The size of the FFT to apply.
# If not provided, uses the smallest power of 2 enclosing frame_length.
fft_length = 384


def encode_single_sample(wav_file, label):
    ###########################################
    ## Process the Audio
    ##########################################
    # 1. Read wav file
    file = tf.io.read_file(wavs_path + wav_file + ".wav")
    # 2. Decode the wav file
    audio, _ = tf.audio.decode_wav(file)
    audio = tf.squeeze(audio, axis=-1)
    # 3. Change type to float
    audio = tf.cast(audio, tf.float32)
    # 4. Get the spectrogram
    spectrogram = tf.signal.stft(
        audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length
    )
    # 5. We only need the magnitude, which can be derived by applying tf.abs
    spectrogram = tf.abs(spectrogram)
    spectrogram = tf.math.pow(spectrogram, 0.5)
    # 6. Normalisation
    means = tf.math.reduce_mean(spectrogram, 1, keepdims=True)
    stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True)
    spectrogram = (spectrogram - means) / (stddevs + 1e-10)
    ###########################################
    ## Process the label
    ##########################################
    # 7. Convert label to Lower case
    label = tf.strings.lower(label)
    # 8. Split the label
    label = tf.strings.unicode_split(label, input_encoding="UTF-8")
    # 9. Map the characters in label to numbers
    label = char_to_num(label)
    # 10. Return the spectrogram and the encoded label
    return spectrogram, label
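
"""
As a rough sanity check (an illustrative sketch, not part of the original example), the
STFT above produces `fft_length // 2 + 1 = 193` frequency bins per frame and roughly
`1 + (num_samples - frame_length) // frame_step` frames per clip. This is why the model
below uses `input_dim=fft_length // 2 + 1`.
"""

example_num_samples = 22050  # a hypothetical one-second clip at 22,050 Hz
print("Frequency bins per frame:", fft_length // 2 + 1)
print("Approximate frames:", 1 + (example_num_samples - frame_length) // frame_step)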


"""
## Creating `Dataset` objects

We create a `tf.data.Dataset` object that yields
the transformed elements, in the same order as they
appeared in the input.
"""

batch_size = 32
# Define the training dataset
train_dataset = tf.data.Dataset.from_tensor_slices(
    (list(df_train["file_name"]), list(df_train["normalized_transcription"]))
)
train_dataset = (
    train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

# Define the validation dataset
validation_dataset = tf.data.Dataset.from_tensor_slices(
    (list(df_val["file_name"]), list(df_val["normalized_transcription"]))
)
validation_dataset = (
    validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
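
"""
Since clips and transcripts have different lengths, `padded_batch` pads every element to
the longest one in its batch. The following quick check (a small sketch, not part of the
original example) prints the resulting batch shapes:
"""

for spectrograms, labels in train_dataset.take(1):
    # (batch_size, time_frames, fft_length // 2 + 1) and (batch_size, max_label_length)
    print("Spectrogram batch shape:", spectrograms.shape)
    print("Label batch shape:", labels.shape)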


"""
## Visualize the data

Let's visualize an example in our dataset, including the
audio clip, the spectrogram and the corresponding label.
"""

fig = plt.figure(figsize=(8, 5))
for batch in train_dataset.take(1):
    spectrogram = batch[0][0].numpy()
    spectrogram = np.array([np.trim_zeros(x) for x in np.transpose(spectrogram)])
    label = batch[1][0]
    # Spectrogram
    label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
    ax = plt.subplot(2, 1, 1)
    ax.imshow(spectrogram, vmax=1)
    ax.set_title(label)
    ax.axis("off")
    # Wav
    file = tf.io.read_file(wavs_path + list(df_train["file_name"])[0] + ".wav")
    audio, _ = tf.audio.decode_wav(file)
    audio = audio.numpy()
    ax = plt.subplot(2, 1, 2)
    plt.plot(audio)
    ax.set_title("Signal Wave")
    ax.set_xlim(0, len(audio))
    # LJSpeech is sampled at 22,050 Hz
    display.display(display.Audio(np.transpose(audio), rate=22050))
plt.show()

"""
## Model

We first define the CTC Loss function.
"""


def CTCLoss(y_true, y_pred):
    # Compute the training-time loss value
    batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
    input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
    label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

    input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
    label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

    loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
    return loss
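
"""
To make the expected tensor shapes concrete (a small illustrative sketch, not part of the
original example): `y_pred` holds per-frame class probabilities of shape
`(batch, time_steps, vocabulary_size + 1)`, where the extra class is the CTC "blank"
token, and `y_true` holds the integer-encoded transcripts of shape `(batch, label_length)`.
"""

dummy_true = tf.stack(
    [
        char_to_num(tf.strings.unicode_split("hello", input_encoding="UTF-8")),
        char_to_num(tf.strings.unicode_split("world", input_encoding="UTF-8")),
    ]
)
dummy_pred = tf.nn.softmax(tf.random.normal((2, 50, char_to_num.vocabulary_size() + 1)))
print("CTC loss per sample:", CTCLoss(dummy_true, dummy_pred).numpy().flatten())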


"""
We now define our model. It will be similar to
[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html).
"""


def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
    """Model similar to DeepSpeech2."""
    # Model's input
    input_spectrogram = layers.Input((None, input_dim), name="input")
    # Expand the dimension to use 2D CNN.
    x = layers.Reshape((-1, input_dim, 1), name="expand_dim")(input_spectrogram)
    # Convolution layer 1
    x = layers.Conv2D(
        filters=32,
        kernel_size=[11, 41],
        strides=[2, 2],
        padding="same",
        use_bias=False,
        name="conv_1",
    )(x)
    x = layers.BatchNormalization(name="conv_1_bn")(x)
    x = layers.ReLU(name="conv_1_relu")(x)
    # Convolution layer 2
    x = layers.Conv2D(
        filters=32,
        kernel_size=[11, 21],
        strides=[1, 2],
        padding="same",
        use_bias=False,
        name="conv_2",
    )(x)
    x = layers.BatchNormalization(name="conv_2_bn")(x)
    x = layers.ReLU(name="conv_2_relu")(x)
    # Reshape the resulting volume to feed the RNN layers
    x = layers.Reshape((-1, x.shape[-2] * x.shape[-1]))(x)
    # RNN layers
    for i in range(1, rnn_layers + 1):
        recurrent = layers.GRU(
            units=rnn_units,
            activation="tanh",
            recurrent_activation="sigmoid",
            use_bias=True,
            return_sequences=True,
            reset_after=True,
            name=f"gru_{i}",
        )
        x = layers.Bidirectional(
            recurrent, name=f"bidirectional_{i}", merge_mode="concat"
        )(x)
        if i < rnn_layers:
            x = layers.Dropout(rate=0.5)(x)
    # Dense layer
    x = layers.Dense(units=rnn_units * 2, name="dense_1")(x)
    x = layers.ReLU(name="dense_1_relu")(x)
    x = layers.Dropout(rate=0.5)(x)
    # Classification layer
    output = layers.Dense(units=output_dim + 1, activation="softmax")(x)
    # Model
    model = keras.Model(input_spectrogram, output, name="DeepSpeech_2")
    # Optimizer
    opt = keras.optimizers.Adam(learning_rate=1e-4)
    # Compile the model and return
    model.compile(optimizer=opt, loss=CTCLoss)
    return model


# Get the model
model = build_model(
    input_dim=fft_length // 2 + 1,
    output_dim=char_to_num.vocabulary_size(),
    rnn_units=512,
)
model.summary(line_length=110)

"""
## Training and Evaluating
"""


# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
    # Iterate over the results and get back the text
    output_text = []
    for result in results:
        result = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8")
        output_text.append(result)
    return output_text
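
"""
The greedy decoder above commits to the single most likely character at every frame. For
reference (a sketch under the same assumptions, not used in the rest of the example),
beam search decoding only requires switching the flags on `keras.backend.ctc_decode`:
"""


def decode_batch_predictions_beam(pred, beam_width=100):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Keep the `beam_width` most likely prefixes at each step and return the best path
    results = keras.backend.ctc_decode(
        pred, input_length=input_len, greedy=False, beam_width=beam_width, top_paths=1
    )[0][0]
    return [
        tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8")
        for result in results
    ]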


# A callback class to output a few transcriptions during training
class CallbackEval(keras.callbacks.Callback):
    """Displays a batch of outputs after every epoch."""

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def on_epoch_end(self, epoch: int, logs=None):
        predictions = []
        targets = []
        for batch in self.dataset:
            X, y = batch
            batch_predictions = model.predict(X)
            batch_predictions = decode_batch_predictions(batch_predictions)
            predictions.extend(batch_predictions)
            for label in y:
                label = (
                    tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
                )
                targets.append(label)
        wer_score = wer(targets, predictions)
        print("-" * 100)
        print(f"Word Error Rate: {wer_score:.4f}")
        print("-" * 100)
        for i in np.random.randint(0, len(predictions), 2):
            print(f"Target    : {targets[i]}")
            print(f"Prediction: {predictions[i]}")
            print("-" * 100)


"""
Let's start the training process.
"""

# Define the number of epochs.
epochs = 1
# Callback function to check transcription on the val set.
validation_callback = CallbackEval(validation_dataset)
# Train the model
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    callbacks=[validation_callback],
)


"""
## Inference
"""

# Let's check results on more validation samples
predictions = []
targets = []
for batch in validation_dataset:
    X, y = batch
    batch_predictions = model.predict(X)
    batch_predictions = decode_batch_predictions(batch_predictions)
    predictions.extend(batch_predictions)
    for label in y:
        label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
        targets.append(label)
wer_score = wer(targets, predictions)
print("-" * 100)
print(f"Word Error Rate: {wer_score:.4f}")
print("-" * 100)
for i in np.random.randint(0, len(predictions), 5):
    print(f"Target    : {targets[i]}")
    print(f"Prediction: {predictions[i]}")
    print("-" * 100)


"""
## Conclusion

In practice, you should train for around 50 epochs or more. Each epoch
takes approximately 5-6 minutes using a `GeForce RTX 2080 Ti` GPU.
The model we trained for 50 epochs has a Word Error Rate (WER) of approximately 16% to 17%.

Some of the transcriptions around epoch 50:

**Audio file: LJ017-0009.wav**
```
- Target    : sir thomas overbury was undoubtedly poisoned by lord rochester in the reign
of james the first
- Prediction: cer thomas overbery was undoubtedly poisoned by lordrochester in the reign
of james the first
```

**Audio file: LJ003-0340.wav**
```
- Target    : the committee does not seem to have yet understood that newgate could be
only and properly replaced
- Prediction: the committee does not seem to have yet understood that newgate could be
only and proberly replace
```

**Audio file: LJ011-0136.wav**
```
- Target    : still no sentence of death was carried out for the offense and in eighteen
thirtytwo
- Prediction: still no sentence of death was carried out for the offense and in eighteen
thirtytwo
```

A trained model and an interactive demo are available on Hugging Face:

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/🤗%20Model-CTC%20ASR-black.svg)](https://huggingface.co/keras-io/ctc_asr) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-CTC%20ASR-black.svg)](https://huggingface.co/spaces/keras-io/ctc_asr) |
"""