"""
Title: Automatic Speech Recognition with Transformer
Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
Date created: 2021/01/13
Last modified: 2021/01/13
Description: Training a sequence-to-sequence Transformer for automatic speech recognition.
Accelerator: GPU
"""

"""
## Introduction

Automatic speech recognition (ASR) consists of transcribing audio speech segments into text.
ASR can be treated as a sequence-to-sequence problem, where the
audio can be represented as a sequence of feature vectors
and the text as a sequence of characters, words, or subword tokens.

For this demonstration, we will use the LJSpeech dataset from the
[LibriVox](https://librivox.org/) project. It consists of short
audio clips of a single speaker reading passages from 7 non-fiction books.
Our model will be similar to the original Transformer (both encoder and decoder)
as proposed in the paper, "Attention is All You Need".

**References:**

- [Attention is All You Need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
- [Very Deep Self-Attention Networks for End-to-End Speech Recognition](https://arxiv.org/abs/1904.13377)
- [Speech Transformers](https://ieeexplore.ieee.org/document/8462506)
- [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/)
"""

import re
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

from glob import glob
import tensorflow as tf
import keras
from keras import layers

"""
45
## Define the Transformer Input Layer
46
47
When processing past target tokens for the decoder, we compute the sum of
48
position embeddings and token embeddings.
49
50
When processing audio features, we apply convolutional layers to downsample
51
them (via convolution strides) and process local relationships.
52
"""
53
54
55
class TokenEmbedding(layers.Layer):
56
def __init__(self, num_vocab=1000, maxlen=100, num_hid=64):
57
super().__init__()
58
self.emb = keras.layers.Embedding(num_vocab, num_hid)
59
self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=num_hid)
60
61
def call(self, x):
62
maxlen = tf.shape(x)[-1]
63
x = self.emb(x)
64
positions = tf.range(start=0, limit=maxlen, delta=1)
65
positions = self.pos_emb(positions)
66
return x + positions
67
68
69
class SpeechFeatureEmbedding(layers.Layer):
70
def __init__(self, num_hid=64, maxlen=100):
71
super().__init__()
72
self.conv1 = keras.layers.Conv1D(
73
num_hid, 11, strides=2, padding="same", activation="relu"
74
)
75
self.conv2 = keras.layers.Conv1D(
76
num_hid, 11, strides=2, padding="same", activation="relu"
77
)
78
self.conv3 = keras.layers.Conv1D(
79
num_hid, 11, strides=2, padding="same", activation="relu"
80
)
81
82
def call(self, x):
83
x = self.conv1(x)
84
x = self.conv2(x)
85
return self.conv3(x)
86
87
88
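
"""
As a quick, illustrative sanity check (not part of the original example), the three
stride-2 convolutions shorten the audio time axis by a factor of 8, while the token
embedding preserves sequence length. The shapes below are made up for the demo:
"""

# Illustrative only: shape checks on random toy inputs.
_demo_audio = tf.random.normal((2, 80, 129))
print(SpeechFeatureEmbedding(num_hid=64)(_demo_audio).shape)  # (2, 10, 64)
_demo_tokens = tf.constant([[2, 5, 6, 3]])
print(TokenEmbedding(num_vocab=34, maxlen=10, num_hid=64)(_demo_tokens).shape)  # (1, 4, 64)
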
"""
89
## Transformer Encoder Layer
90
"""
91
92
93
class TransformerEncoder(layers.Layer):
94
def __init__(self, embed_dim, num_heads, feed_forward_dim, rate=0.1):
95
super().__init__()
96
self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
97
self.ffn = keras.Sequential(
98
[
99
layers.Dense(feed_forward_dim, activation="relu"),
100
layers.Dense(embed_dim),
101
]
102
)
103
self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
104
self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
105
self.dropout1 = layers.Dropout(rate)
106
self.dropout2 = layers.Dropout(rate)
107
108
def call(self, inputs, training=False):
109
attn_output = self.att(inputs, inputs)
110
attn_output = self.dropout1(attn_output, training=training)
111
out1 = self.layernorm1(inputs + attn_output)
112
ffn_output = self.ffn(out1)
113
ffn_output = self.dropout2(ffn_output, training=training)
114
return self.layernorm2(out1 + ffn_output)
115
116
117
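
"""
Each encoder layer maps a sequence of feature vectors to a sequence of the same
shape, which is what lets us stack several of them below. A minimal sketch
(illustrative only, with made-up shapes):
"""

# Illustrative only: the encoder layer preserves the (batch, time, embed_dim) shape.
_demo_features = tf.random.normal((2, 10, 64))
print(
    TransformerEncoder(embed_dim=64, num_heads=2, feed_forward_dim=128)(
        _demo_features
    ).shape
)  # (2, 10, 64)
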
"""
118
## Transformer Decoder Layer
119
"""
120
121
122
class TransformerDecoder(layers.Layer):
123
def __init__(self, embed_dim, num_heads, feed_forward_dim, dropout_rate=0.1):
124
super().__init__()
125
self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
126
self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
127
self.layernorm3 = layers.LayerNormalization(epsilon=1e-6)
128
self.self_att = layers.MultiHeadAttention(
129
num_heads=num_heads, key_dim=embed_dim
130
)
131
self.enc_att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
132
self.self_dropout = layers.Dropout(0.5)
133
self.enc_dropout = layers.Dropout(0.1)
134
self.ffn_dropout = layers.Dropout(0.1)
135
self.ffn = keras.Sequential(
136
[
137
layers.Dense(feed_forward_dim, activation="relu"),
138
layers.Dense(embed_dim),
139
]
140
)
141
142
def causal_attention_mask(self, batch_size, n_dest, n_src, dtype):
143
"""Masks the upper half of the dot product matrix in self attention.
144
145
This prevents flow of information from future tokens to current token.
146
1's in the lower triangle, counting from the lower right corner.
147
"""
148
i = tf.range(n_dest)[:, None]
149
j = tf.range(n_src)
150
m = i >= j - n_src + n_dest
151
mask = tf.cast(m, dtype)
152
mask = tf.reshape(mask, [1, n_dest, n_src])
153
mult = tf.concat(
154
[tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
155
)
156
return tf.tile(mask, mult)
157
158
def call(self, enc_out, target):
159
input_shape = tf.shape(target)
160
batch_size = input_shape[0]
161
seq_len = input_shape[1]
162
causal_mask = self.causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)
163
target_att = self.self_att(target, target, attention_mask=causal_mask)
164
target_norm = self.layernorm1(target + self.self_dropout(target_att))
165
enc_out = self.enc_att(target_norm, enc_out)
166
enc_out_norm = self.layernorm2(self.enc_dropout(enc_out) + target_norm)
167
ffn_out = self.ffn(enc_out_norm)
168
ffn_out_norm = self.layernorm3(enc_out_norm + self.ffn_dropout(ffn_out))
169
return ffn_out_norm
170
171
172
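
"""
To see what `causal_attention_mask` produces, here is a quick illustrative check
(not part of the original example): for a batch of 1 and a sequence length of 4,
the mask is lower-triangular, so position `i` may only attend to positions `j <= i`.
"""

# Illustrative only: print the 4x4 causal mask.
_demo_decoder = TransformerDecoder(embed_dim=64, num_heads=2, feed_forward_dim=128)
print(_demo_decoder.causal_attention_mask(1, 4, 4, tf.bool)[0])
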
"""
173
## Complete the Transformer model
174
175
Our model takes audio spectrograms as inputs and predicts a sequence of characters.
176
During training, we give the decoder the target character sequence shifted to the left
177
as input. During inference, the decoder uses its own past predictions to predict the
178
next token.
179
"""
180
181
182
class Transformer(keras.Model):
183
def __init__(
184
self,
185
num_hid=64,
186
num_head=2,
187
num_feed_forward=128,
188
source_maxlen=100,
189
target_maxlen=100,
190
num_layers_enc=4,
191
num_layers_dec=1,
192
num_classes=10,
193
):
194
super().__init__()
195
self.loss_metric = keras.metrics.Mean(name="loss")
196
self.num_layers_enc = num_layers_enc
197
self.num_layers_dec = num_layers_dec
198
self.target_maxlen = target_maxlen
199
self.num_classes = num_classes
200
201
self.enc_input = SpeechFeatureEmbedding(num_hid=num_hid, maxlen=source_maxlen)
202
self.dec_input = TokenEmbedding(
203
num_vocab=num_classes, maxlen=target_maxlen, num_hid=num_hid
204
)
205
206
self.encoder = keras.Sequential(
207
[self.enc_input]
208
+ [
209
TransformerEncoder(num_hid, num_head, num_feed_forward)
210
for _ in range(num_layers_enc)
211
]
212
)
213
214
for i in range(num_layers_dec):
215
setattr(
216
self,
217
f"dec_layer_{i}",
218
TransformerDecoder(num_hid, num_head, num_feed_forward),
219
)
220
221
self.classifier = layers.Dense(num_classes)
222
223
def decode(self, enc_out, target):
224
y = self.dec_input(target)
225
for i in range(self.num_layers_dec):
226
y = getattr(self, f"dec_layer_{i}")(enc_out, y)
227
return y
228
229
def call(self, inputs):
230
source = inputs[0]
231
target = inputs[1]
232
x = self.encoder(source)
233
y = self.decode(x, target)
234
return self.classifier(y)
235
236
@property
237
def metrics(self):
238
return [self.loss_metric]
239
240
def train_step(self, batch):
241
"""Processes one batch inside model.fit()."""
242
source = batch["source"]
243
target = batch["target"]
244
dec_input = target[:, :-1]
245
dec_target = target[:, 1:]
246
with tf.GradientTape() as tape:
247
preds = self([source, dec_input])
248
one_hot = tf.one_hot(dec_target, depth=self.num_classes)
249
mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
250
loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)
251
trainable_vars = self.trainable_variables
252
gradients = tape.gradient(loss, trainable_vars)
253
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
254
self.loss_metric.update_state(loss)
255
return {"loss": self.loss_metric.result()}
256
257
def test_step(self, batch):
258
source = batch["source"]
259
target = batch["target"]
260
dec_input = target[:, :-1]
261
dec_target = target[:, 1:]
262
preds = self([source, dec_input])
263
one_hot = tf.one_hot(dec_target, depth=self.num_classes)
264
mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
265
loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)
266
self.loss_metric.update_state(loss)
267
return {"loss": self.loss_metric.result()}
268
269
def generate(self, source, target_start_token_idx):
270
"""Performs inference over one batch of inputs using greedy decoding."""
271
bs = tf.shape(source)[0]
272
enc = self.encoder(source)
273
dec_input = tf.ones((bs, 1), dtype=tf.int32) * target_start_token_idx
274
dec_logits = []
275
for i in range(self.target_maxlen - 1):
276
dec_out = self.decode(enc, dec_input)
277
logits = self.classifier(dec_out)
278
logits = tf.argmax(logits, axis=-1, output_type=tf.int32)
279
last_logit = tf.expand_dims(logits[:, -1], axis=-1)
280
dec_logits.append(last_logit)
281
dec_input = tf.concat([dec_input, last_logit], axis=-1)
282
return dec_input
283
284
285
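
"""
The `train_step` and `test_step` methods implement teacher forcing: the decoder
input is the target sequence without its final token, and the loss is computed
against the same sequence shifted one position to the left (padding index 0 is
masked out of the loss). A quick illustrative check, with a made-up token row
(not part of the original example):
"""

# Illustrative only: splitting a padded target row into decoder input and target.
_demo_target = tf.constant([[2, 5, 6, 7, 3, 0]])  # <start>, tokens, <end>, padding
print(_demo_target[:, :-1].numpy())  # decoder input:     [[2 5 6 7 3]]
print(_demo_target[:, 1:].numpy())  # prediction target: [[5 6 7 3 0]]
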
"""
286
## Download the dataset
287
288
Note: This requires ~3.6 GB of disk space and
289
takes ~5 minutes for the extraction of files.
290
"""
291
292
pattern_wav_name = re.compile(r"([^/\\\.]+)")
293
294
keras.utils.get_file(
295
os.path.join(os.getcwd(), "data.tar.gz"),
296
"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",
297
extract=True,
298
archive_format="tar",
299
cache_dir=".",
300
)
301
302
303
saveto = "./datasets/LJSpeech-1.1"
304
wavs = glob("{}/**/*.wav".format(saveto), recursive=True)
305
306
id_to_text = {}
307
with open(os.path.join(saveto, "metadata.csv"), encoding="utf-8") as f:
308
for line in f:
309
id = line.strip().split("|")[0]
310
text = line.strip().split("|")[2]
311
id_to_text[id] = text
312
313
314
def get_data(wavs, id_to_text, maxlen=50):
315
"""returns mapping of audio paths and transcription texts"""
316
data = []
317
for w in wavs:
318
id = pattern_wav_name.split(w)[-4]
319
if len(id_to_text[id]) < maxlen:
320
data.append({"audio": w, "text": id_to_text[id]})
321
return data
322
323
324
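
"""
`pattern_wav_name` splits a path on slashes, backslashes, and dots while keeping
the pieces in between, so the fourth element from the end is the file name without
its extension. An illustrative check on an LJSpeech-style path (not part of the
original example):
"""

# Illustrative only: recover the clip id from its wav path.
print(pattern_wav_name.split("./datasets/LJSpeech-1.1/wavs/LJ001-0001.wav")[-4])
# LJ001-0001
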
"""
325
## Preprocess the dataset
326
"""
327
328
329
class VectorizeChar:
330
def __init__(self, max_len=50):
331
self.vocab = (
332
["-", "#", "<", ">"]
333
+ [chr(i + 96) for i in range(1, 27)]
334
+ [" ", ".", ",", "?"]
335
)
336
self.max_len = max_len
337
self.char_to_idx = {}
338
for i, ch in enumerate(self.vocab):
339
self.char_to_idx[ch] = i
340
341
def __call__(self, text):
342
text = text.lower()
343
text = text[: self.max_len - 2]
344
text = "<" + text + ">"
345
pad_len = self.max_len - len(text)
346
return [self.char_to_idx.get(ch, 1) for ch in text] + [0] * pad_len
347
348
def get_vocabulary(self):
349
return self.vocab
350
351
352
max_target_len = 200  # all transcripts in our data are < 200 characters
data = get_data(wavs, id_to_text, max_target_len)
vectorizer = VectorizeChar(max_target_len)
print("vocab size", len(vectorizer.get_vocabulary()))
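
"""
The vectorizer lowercases text, wraps it in the start token `<` (index 2) and end
token `>` (index 3), maps unknown characters to index 1 (`#`), and pads with
index 0 (`-`) up to `max_len`. An illustrative example (not part of the original):
"""

# Illustrative only: the first 15 token ids for a short sentence.
print(vectorizer("hello world")[:15])
# [2, 11, 8, 15, 15, 18, 30, 26, 18, 21, 15, 7, 3, 0, 0]
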
def create_text_ds(data):
    texts = [_["text"] for _ in data]
    text_ds = [vectorizer(t) for t in texts]
    text_ds = tf.data.Dataset.from_tensor_slices(text_ds)
    return text_ds
def path_to_audio(path):
    # spectrogram using stft
    audio = tf.io.read_file(path)
    audio, _ = tf.audio.decode_wav(audio, 1)
    audio = tf.squeeze(audio, axis=-1)
    stfts = tf.signal.stft(audio, frame_length=200, frame_step=80, fft_length=256)
    # square-root compression of the magnitude spectrogram
    x = tf.math.pow(tf.abs(stfts), 0.5)
    # normalisation
    means = tf.math.reduce_mean(x, 1, keepdims=True)
    stddevs = tf.math.reduce_std(x, 1, keepdims=True)
    x = (x - means) / stddevs
    audio_len = tf.shape(x)[0]
    # pad or trim to 2754 STFT frames, i.e. ~10 seconds of audio at LJSpeech's
    # 22,050 Hz sample rate: 1 + (10 * 22050 - 200) // 80 = 2754
    pad_len = 2754
    paddings = tf.constant([[0, pad_len], [0, 0]])
    x = tf.pad(x, paddings, "CONSTANT")[:pad_len, :]
    return x
def create_audio_ds(data):
    flist = [_["audio"] for _ in data]
    audio_ds = tf.data.Dataset.from_tensor_slices(flist)
    audio_ds = audio_ds.map(path_to_audio, num_parallel_calls=tf.data.AUTOTUNE)
    return audio_ds


def create_tf_dataset(data, bs=4):
    audio_ds = create_audio_ds(data)
    text_ds = create_text_ds(data)
    ds = tf.data.Dataset.zip((audio_ds, text_ds))
    ds = ds.map(lambda x, y: {"source": x, "target": y})
    ds = ds.batch(bs)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds


split = int(len(data) * 0.99)
train_data = data[:split]
test_data = data[split:]
ds = create_tf_dataset(train_data, bs=64)
val_ds = create_tf_dataset(test_data, bs=4)
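
"""
Each dataset element is a dict: `"source"` holds a batch of spectrograms padded or
trimmed to 2754 time frames with 129 frequency bins (`fft_length=256` gives
`256 // 2 + 1 = 129` bins), and `"target"` holds the matching batch of 200 token
ids. A quick illustrative check (not part of the original example):
"""

# Illustrative only: inspect the structure of a training batch.
print(ds.element_spec)
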
"""
408
## Callbacks to display predictions
409
"""
410
411
412
class DisplayOutputs(keras.callbacks.Callback):
    def __init__(
        self, batch, idx_to_token, target_start_token_idx=27, target_end_token_idx=28
    ):
        """Displays a batch of outputs after every 5 epochs

        Args:
            batch: A test batch containing the keys "source" and "target"
            idx_to_token: A list containing the vocabulary tokens corresponding to their indices
            target_start_token_idx: A start token index in the target vocabulary
            target_end_token_idx: An end token index in the target vocabulary
        """
        self.batch = batch
        self.target_start_token_idx = target_start_token_idx
        self.target_end_token_idx = target_end_token_idx
        self.idx_to_char = idx_to_token

    def on_epoch_end(self, epoch, logs=None):
        if epoch % 5 != 0:
            return
        source = self.batch["source"]
        target = self.batch["target"].numpy()
        bs = tf.shape(source)[0]
        preds = self.model.generate(source, self.target_start_token_idx)
        preds = preds.numpy()
        for i in range(bs):
            target_text = "".join([self.idx_to_char[_] for _ in target[i, :]])
            prediction = ""
            for idx in preds[i, :]:
                prediction += self.idx_to_char[idx]
                if idx == self.target_end_token_idx:
                    break
            print(f"target: {target_text.replace('-','')}")
            print(f"prediction: {prediction}\n")
"""
449
## Learning rate schedule
450
"""
451
452
453
class CustomSchedule(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(
        self,
        init_lr=0.00001,
        lr_after_warmup=0.001,
        final_lr=0.00001,
        warmup_epochs=15,
        decay_epochs=85,
        steps_per_epoch=203,
    ):
        super().__init__()
        self.init_lr = init_lr
        self.lr_after_warmup = lr_after_warmup
        self.final_lr = final_lr
        self.warmup_epochs = warmup_epochs
        self.decay_epochs = decay_epochs
        self.steps_per_epoch = steps_per_epoch

    def calculate_lr(self, epoch):
        """linear warm up - linear decay"""
        warmup_lr = (
            self.init_lr
            + ((self.lr_after_warmup - self.init_lr) / (self.warmup_epochs - 1)) * epoch
        )
        decay_lr = tf.math.maximum(
            self.final_lr,
            self.lr_after_warmup
            - (epoch - self.warmup_epochs)
            * (self.lr_after_warmup - self.final_lr)
            / self.decay_epochs,
        )
        return tf.math.minimum(warmup_lr, decay_lr)

    def __call__(self, step):
        epoch = step // self.steps_per_epoch
        epoch = tf.cast(epoch, "float32")
        return self.calculate_lr(epoch)
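
"""
Taking the minimum of the warmup and decay lines gives a schedule that ramps up
linearly to `lr_after_warmup` over `warmup_epochs` and then decays linearly toward
`final_lr`. An illustrative check with a hypothetical `steps_per_epoch=100` (not
part of the original example):
"""

# Illustrative only: learning rate at the start, at the end of warmup, and mid-decay.
_demo_schedule = CustomSchedule(steps_per_epoch=100)
for _step in [0, 15 * 100, 50 * 100]:
    print(_step, float(_demo_schedule(_step)))
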
"""
493
## Create & train the end-to-end model
494
"""
495
496
batch = next(iter(val_ds))
497
498
# The vocabulary to convert predicted indices into characters
499
idx_to_char = vectorizer.get_vocabulary()
500
display_cb = DisplayOutputs(
501
batch, idx_to_char, target_start_token_idx=2, target_end_token_idx=3
502
) # set the arguments as per vocabulary index for '<' and '>'
503
504
model = Transformer(
    num_hid=200,
    num_head=2,
    num_feed_forward=400,
    target_maxlen=max_target_len,
    num_layers_enc=4,
    num_layers_dec=1,
    num_classes=34,
)
loss_fn = keras.losses.CategoricalCrossentropy(
    from_logits=True,
    label_smoothing=0.1,
)

learning_rate = CustomSchedule(
    init_lr=0.00001,
    lr_after_warmup=0.001,
    final_lr=0.00001,
    warmup_epochs=15,
    decay_epochs=85,
    steps_per_epoch=len(ds),
)
optimizer = keras.optimizers.Adam(learning_rate)
model.compile(optimizer=optimizer, loss=loss_fn)

history = model.fit(ds, validation_data=val_ds, callbacks=[display_cb], epochs=1)
"""
532
In practice, you should train for around 100 epochs or more.
533
534
Some of the predicted text at or around epoch 35 may look as follows:
535
```
536
target: <as they sat in the car, frazier asked oswald where his lunch was>
537
prediction: <as they sat in the car frazier his lunch ware mis lunch was>
538
539
target: <under the entry for may one, nineteen sixty,>
540
prediction: <under the introus for may monee, nin the sixty,>
541
```
542
"""
543
544