Path: blob/master/examples/audio/transformer_asr.py
3507 views
"""1Title: Automatic Speech Recognition with Transformer2Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)3Date created: 2021/01/134Last modified: 2021/01/135Description: Training a sequence-to-sequence Transformer for automatic speech recognition.6Accelerator: GPU7"""89"""10## Introduction1112Automatic speech recognition (ASR) consists of transcribing audio speech segments into text.13ASR can be treated as a sequence-to-sequence problem, where the14audio can be represented as a sequence of feature vectors15and the text as a sequence of characters, words, or subword tokens.1617For this demonstration, we will use the LJSpeech dataset from the18[LibriVox](https://librivox.org/) project. It consists of short19audio clips of a single speaker reading passages from 7 non-fiction books.20Our model will be similar to the original Transformer (both encoder and decoder)21as proposed in the paper, "Attention is All You Need".222324**References:**2526- [Attention is All You Need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)27- [Very Deep Self-Attention Networks for End-to-End Speech Recognition](https://arxiv.org/abs/1904.13377)28- [Speech Transformers](https://ieeexplore.ieee.org/document/8462506)29- [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/)30"""3132import re33import os3435os.environ["KERAS_BACKEND"] = "tensorflow"3637from glob import glob38import tensorflow as tf39import keras40from keras import layers414243"""44## Define the Transformer Input Layer4546When processing past target tokens for the decoder, we compute the sum of47position embeddings and token embeddings.4849When processing audio features, we apply convolutional layers to downsample50them (via convolution strides) and process local relationships.51"""525354class TokenEmbedding(layers.Layer):55def __init__(self, num_vocab=1000, maxlen=100, num_hid=64):56super().__init__()57self.emb = keras.layers.Embedding(num_vocab, num_hid)58self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=num_hid)5960def call(self, x):61maxlen = tf.shape(x)[-1]62x = self.emb(x)63positions = tf.range(start=0, limit=maxlen, delta=1)64positions = self.pos_emb(positions)65return x + positions666768class SpeechFeatureEmbedding(layers.Layer):69def __init__(self, num_hid=64, maxlen=100):70super().__init__()71self.conv1 = keras.layers.Conv1D(72num_hid, 11, strides=2, padding="same", activation="relu"73)74self.conv2 = keras.layers.Conv1D(75num_hid, 11, strides=2, padding="same", activation="relu"76)77self.conv3 = keras.layers.Conv1D(78num_hid, 11, strides=2, padding="same", activation="relu"79)8081def call(self, x):82x = self.conv1(x)83x = self.conv2(x)84return self.conv3(x)858687"""88## Transformer Encoder Layer89"""909192class TransformerEncoder(layers.Layer):93def __init__(self, embed_dim, num_heads, feed_forward_dim, rate=0.1):94super().__init__()95self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)96self.ffn = keras.Sequential(97[98layers.Dense(feed_forward_dim, activation="relu"),99layers.Dense(embed_dim),100]101)102self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)103self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)104self.dropout1 = layers.Dropout(rate)105self.dropout2 = layers.Dropout(rate)106107def call(self, inputs, training=False):108attn_output = self.att(inputs, inputs)109attn_output = self.dropout1(attn_output, training=training)110out1 = self.layernorm1(inputs + attn_output)111ffn_output = self.ffn(out1)112ffn_output = self.dropout2(ffn_output, training=training)113return self.layernorm2(out1 + ffn_output)114115116"""117## Transformer Decoder Layer118"""119120121class TransformerDecoder(layers.Layer):122def __init__(self, embed_dim, num_heads, feed_forward_dim, dropout_rate=0.1):123super().__init__()124self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)125self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)126self.layernorm3 = layers.LayerNormalization(epsilon=1e-6)127self.self_att = layers.MultiHeadAttention(128num_heads=num_heads, key_dim=embed_dim129)130self.enc_att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)131self.self_dropout = layers.Dropout(0.5)132self.enc_dropout = layers.Dropout(0.1)133self.ffn_dropout = layers.Dropout(0.1)134self.ffn = keras.Sequential(135[136layers.Dense(feed_forward_dim, activation="relu"),137layers.Dense(embed_dim),138]139)140141def causal_attention_mask(self, batch_size, n_dest, n_src, dtype):142"""Masks the upper half of the dot product matrix in self attention.143144This prevents flow of information from future tokens to current token.1451's in the lower triangle, counting from the lower right corner.146"""147i = tf.range(n_dest)[:, None]148j = tf.range(n_src)149m = i >= j - n_src + n_dest150mask = tf.cast(m, dtype)151mask = tf.reshape(mask, [1, n_dest, n_src])152mult = tf.concat(153[tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0154)155return tf.tile(mask, mult)156157def call(self, enc_out, target):158input_shape = tf.shape(target)159batch_size = input_shape[0]160seq_len = input_shape[1]161causal_mask = self.causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)162target_att = self.self_att(target, target, attention_mask=causal_mask)163target_norm = self.layernorm1(target + self.self_dropout(target_att))164enc_out = self.enc_att(target_norm, enc_out)165enc_out_norm = self.layernorm2(self.enc_dropout(enc_out) + target_norm)166ffn_out = self.ffn(enc_out_norm)167ffn_out_norm = self.layernorm3(enc_out_norm + self.ffn_dropout(ffn_out))168return ffn_out_norm169170171"""172## Complete the Transformer model173174Our model takes audio spectrograms as inputs and predicts a sequence of characters.175During training, we give the decoder the target character sequence shifted to the left176as input. During inference, the decoder uses its own past predictions to predict the177next token.178"""179180181class Transformer(keras.Model):182def __init__(183self,184num_hid=64,185num_head=2,186num_feed_forward=128,187source_maxlen=100,188target_maxlen=100,189num_layers_enc=4,190num_layers_dec=1,191num_classes=10,192):193super().__init__()194self.loss_metric = keras.metrics.Mean(name="loss")195self.num_layers_enc = num_layers_enc196self.num_layers_dec = num_layers_dec197self.target_maxlen = target_maxlen198self.num_classes = num_classes199200self.enc_input = SpeechFeatureEmbedding(num_hid=num_hid, maxlen=source_maxlen)201self.dec_input = TokenEmbedding(202num_vocab=num_classes, maxlen=target_maxlen, num_hid=num_hid203)204205self.encoder = keras.Sequential(206[self.enc_input]207+ [208TransformerEncoder(num_hid, num_head, num_feed_forward)209for _ in range(num_layers_enc)210]211)212213for i in range(num_layers_dec):214setattr(215self,216f"dec_layer_{i}",217TransformerDecoder(num_hid, num_head, num_feed_forward),218)219220self.classifier = layers.Dense(num_classes)221222def decode(self, enc_out, target):223y = self.dec_input(target)224for i in range(self.num_layers_dec):225y = getattr(self, f"dec_layer_{i}")(enc_out, y)226return y227228def call(self, inputs):229source = inputs[0]230target = inputs[1]231x = self.encoder(source)232y = self.decode(x, target)233return self.classifier(y)234235@property236def metrics(self):237return [self.loss_metric]238239def train_step(self, batch):240"""Processes one batch inside model.fit()."""241source = batch["source"]242target = batch["target"]243dec_input = target[:, :-1]244dec_target = target[:, 1:]245with tf.GradientTape() as tape:246preds = self([source, dec_input])247one_hot = tf.one_hot(dec_target, depth=self.num_classes)248mask = tf.math.logical_not(tf.math.equal(dec_target, 0))249loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)250trainable_vars = self.trainable_variables251gradients = tape.gradient(loss, trainable_vars)252self.optimizer.apply_gradients(zip(gradients, trainable_vars))253self.loss_metric.update_state(loss)254return {"loss": self.loss_metric.result()}255256def test_step(self, batch):257source = batch["source"]258target = batch["target"]259dec_input = target[:, :-1]260dec_target = target[:, 1:]261preds = self([source, dec_input])262one_hot = tf.one_hot(dec_target, depth=self.num_classes)263mask = tf.math.logical_not(tf.math.equal(dec_target, 0))264loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)265self.loss_metric.update_state(loss)266return {"loss": self.loss_metric.result()}267268def generate(self, source, target_start_token_idx):269"""Performs inference over one batch of inputs using greedy decoding."""270bs = tf.shape(source)[0]271enc = self.encoder(source)272dec_input = tf.ones((bs, 1), dtype=tf.int32) * target_start_token_idx273dec_logits = []274for i in range(self.target_maxlen - 1):275dec_out = self.decode(enc, dec_input)276logits = self.classifier(dec_out)277logits = tf.argmax(logits, axis=-1, output_type=tf.int32)278last_logit = tf.expand_dims(logits[:, -1], axis=-1)279dec_logits.append(last_logit)280dec_input = tf.concat([dec_input, last_logit], axis=-1)281return dec_input282283284"""285## Download the dataset286287Note: This requires ~3.6 GB of disk space and288takes ~5 minutes for the extraction of files.289"""290291pattern_wav_name = re.compile(r"([^/\\\.]+)")292293keras.utils.get_file(294os.path.join(os.getcwd(), "data.tar.gz"),295"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",296extract=True,297archive_format="tar",298cache_dir=".",299)300301302saveto = "./datasets/LJSpeech-1.1"303wavs = glob("{}/**/*.wav".format(saveto), recursive=True)304305id_to_text = {}306with open(os.path.join(saveto, "metadata.csv"), encoding="utf-8") as f:307for line in f:308id = line.strip().split("|")[0]309text = line.strip().split("|")[2]310id_to_text[id] = text311312313def get_data(wavs, id_to_text, maxlen=50):314"""returns mapping of audio paths and transcription texts"""315data = []316for w in wavs:317id = pattern_wav_name.split(w)[-4]318if len(id_to_text[id]) < maxlen:319data.append({"audio": w, "text": id_to_text[id]})320return data321322323"""324## Preprocess the dataset325"""326327328class VectorizeChar:329def __init__(self, max_len=50):330self.vocab = (331["-", "#", "<", ">"]332+ [chr(i + 96) for i in range(1, 27)]333+ [" ", ".", ",", "?"]334)335self.max_len = max_len336self.char_to_idx = {}337for i, ch in enumerate(self.vocab):338self.char_to_idx[ch] = i339340def __call__(self, text):341text = text.lower()342text = text[: self.max_len - 2]343text = "<" + text + ">"344pad_len = self.max_len - len(text)345return [self.char_to_idx.get(ch, 1) for ch in text] + [0] * pad_len346347def get_vocabulary(self):348return self.vocab349350351max_target_len = 200 # all transcripts in out data are < 200 characters352data = get_data(wavs, id_to_text, max_target_len)353vectorizer = VectorizeChar(max_target_len)354print("vocab size", len(vectorizer.get_vocabulary()))355356357def create_text_ds(data):358texts = [_["text"] for _ in data]359text_ds = [vectorizer(t) for t in texts]360text_ds = tf.data.Dataset.from_tensor_slices(text_ds)361return text_ds362363364def path_to_audio(path):365# spectrogram using stft366audio = tf.io.read_file(path)367audio, _ = tf.audio.decode_wav(audio, 1)368audio = tf.squeeze(audio, axis=-1)369stfts = tf.signal.stft(audio, frame_length=200, frame_step=80, fft_length=256)370x = tf.math.pow(tf.abs(stfts), 0.5)371# normalisation372means = tf.math.reduce_mean(x, 1, keepdims=True)373stddevs = tf.math.reduce_std(x, 1, keepdims=True)374x = (x - means) / stddevs375audio_len = tf.shape(x)[0]376# padding to 10 seconds377pad_len = 2754378paddings = tf.constant([[0, pad_len], [0, 0]])379x = tf.pad(x, paddings, "CONSTANT")[:pad_len, :]380return x381382383def create_audio_ds(data):384flist = [_["audio"] for _ in data]385audio_ds = tf.data.Dataset.from_tensor_slices(flist)386audio_ds = audio_ds.map(path_to_audio, num_parallel_calls=tf.data.AUTOTUNE)387return audio_ds388389390def create_tf_dataset(data, bs=4):391audio_ds = create_audio_ds(data)392text_ds = create_text_ds(data)393ds = tf.data.Dataset.zip((audio_ds, text_ds))394ds = ds.map(lambda x, y: {"source": x, "target": y})395ds = ds.batch(bs)396ds = ds.prefetch(tf.data.AUTOTUNE)397return ds398399400split = int(len(data) * 0.99)401train_data = data[:split]402test_data = data[split:]403ds = create_tf_dataset(train_data, bs=64)404val_ds = create_tf_dataset(test_data, bs=4)405406"""407## Callbacks to display predictions408"""409410411class DisplayOutputs(keras.callbacks.Callback):412def __init__(413self, batch, idx_to_token, target_start_token_idx=27, target_end_token_idx=28414):415"""Displays a batch of outputs after every epoch416417Args:418batch: A test batch containing the keys "source" and "target"419idx_to_token: A List containing the vocabulary tokens corresponding to their indices420target_start_token_idx: A start token index in the target vocabulary421target_end_token_idx: An end token index in the target vocabulary422"""423self.batch = batch424self.target_start_token_idx = target_start_token_idx425self.target_end_token_idx = target_end_token_idx426self.idx_to_char = idx_to_token427428def on_epoch_end(self, epoch, logs=None):429if epoch % 5 != 0:430return431source = self.batch["source"]432target = self.batch["target"].numpy()433bs = tf.shape(source)[0]434preds = self.model.generate(source, self.target_start_token_idx)435preds = preds.numpy()436for i in range(bs):437target_text = "".join([self.idx_to_char[_] for _ in target[i, :]])438prediction = ""439for idx in preds[i, :]:440prediction += self.idx_to_char[idx]441if idx == self.target_end_token_idx:442break443print(f"target: {target_text.replace('-','')}")444print(f"prediction: {prediction}\n")445446447"""448## Learning rate schedule449"""450451452class CustomSchedule(keras.optimizers.schedules.LearningRateSchedule):453def __init__(454self,455init_lr=0.00001,456lr_after_warmup=0.001,457final_lr=0.00001,458warmup_epochs=15,459decay_epochs=85,460steps_per_epoch=203,461):462super().__init__()463self.init_lr = init_lr464self.lr_after_warmup = lr_after_warmup465self.final_lr = final_lr466self.warmup_epochs = warmup_epochs467self.decay_epochs = decay_epochs468self.steps_per_epoch = steps_per_epoch469470def calculate_lr(self, epoch):471"""linear warm up - linear decay"""472warmup_lr = (473self.init_lr474+ ((self.lr_after_warmup - self.init_lr) / (self.warmup_epochs - 1)) * epoch475)476decay_lr = tf.math.maximum(477self.final_lr,478self.lr_after_warmup479- (epoch - self.warmup_epochs)480* (self.lr_after_warmup - self.final_lr)481/ self.decay_epochs,482)483return tf.math.minimum(warmup_lr, decay_lr)484485def __call__(self, step):486epoch = step // self.steps_per_epoch487epoch = tf.cast(epoch, "float32")488return self.calculate_lr(epoch)489490491"""492## Create & train the end-to-end model493"""494495batch = next(iter(val_ds))496497# The vocabulary to convert predicted indices into characters498idx_to_char = vectorizer.get_vocabulary()499display_cb = DisplayOutputs(500batch, idx_to_char, target_start_token_idx=2, target_end_token_idx=3501) # set the arguments as per vocabulary index for '<' and '>'502503model = Transformer(504num_hid=200,505num_head=2,506num_feed_forward=400,507target_maxlen=max_target_len,508num_layers_enc=4,509num_layers_dec=1,510num_classes=34,511)512loss_fn = keras.losses.CategoricalCrossentropy(513from_logits=True,514label_smoothing=0.1,515)516517learning_rate = CustomSchedule(518init_lr=0.00001,519lr_after_warmup=0.001,520final_lr=0.00001,521warmup_epochs=15,522decay_epochs=85,523steps_per_epoch=len(ds),524)525optimizer = keras.optimizers.Adam(learning_rate)526model.compile(optimizer=optimizer, loss=loss_fn)527528history = model.fit(ds, validation_data=val_ds, callbacks=[display_cb], epochs=1)529530"""531In practice, you should train for around 100 epochs or more.532533Some of the predicted text at or around epoch 35 may look as follows:534```535target: <as they sat in the car, frazier asked oswald where his lunch was>536prediction: <as they sat in the car frazier his lunch ware mis lunch was>537538target: <under the entry for may one, nineteen sixty,>539prediction: <under the introus for may monee, nin the sixty,>540```541"""542543544