Path: blob/master/Sequence Models/Week 3/Machine Translation/nmt_utils.py
import numpy as np
from faker import Faker
import random
from tqdm import tqdm
from babel.dates import format_date
from keras.utils import to_categorical
import keras.backend as K
import matplotlib.pyplot as plt

fake = Faker()
fake.seed(12345)
random.seed(12345)

# Define format of the data we would like to generate
FORMATS = ['short',
           'medium',
           'long',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'd MMM YYY',
           'd MMMM YYY',
           'dd MMM YYY',
           'd MMM, YYY',
           'd MMMM, YYY',
           'dd, MMM YYY',
           'd MM YY',
           'd MMMM YYY',
           'MMMM d YYY',
           'MMMM d, YYY',
           'dd.MM.YY']

# change this if you want it to work with another language
LOCALES = ['en_US']


def load_date():
    """
    Loads some fake dates
    :returns: tuple containing human readable string, machine readable string, and date object
    """
    dt = fake.date_object()

    try:
        human_readable = format_date(dt, format=random.choice(FORMATS), locale='en_US')  # locale=random.choice(LOCALES))
        human_readable = human_readable.lower()
        human_readable = human_readable.replace(',', '')
        machine_readable = dt.isoformat()

    except AttributeError as e:
        return None, None, None

    return human_readable, machine_readable, dt


def load_dataset(m):
    """
    Loads a dataset with m examples and vocabularies
    :m: the number of examples to generate
    """
    human_vocab = set()
    machine_vocab = set()
    dataset = []
    Tx = 30

    for i in tqdm(range(m)):
        h, m, _ = load_date()
        if h is not None:
            dataset.append((h, m))
            human_vocab.update(tuple(h))
            machine_vocab.update(tuple(m))

    human = dict(zip(sorted(human_vocab) + ['<unk>', '<pad>'],
                     list(range(len(human_vocab) + 2))))
    inv_machine = dict(enumerate(sorted(machine_vocab)))
    machine = {v: k for k, v in inv_machine.items()}

    return dataset, human, machine, inv_machine


def preprocess_data(dataset, human_vocab, machine_vocab, Tx, Ty):

    X, Y = zip(*dataset)

    X = np.array([string_to_int(i, Tx, human_vocab) for i in X])
    Y = [string_to_int(t, Ty, machine_vocab) for t in Y]

    Xoh = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), X)))
    Yoh = np.array(list(map(lambda x: to_categorical(x, num_classes=len(machine_vocab)), Y)))

    return X, np.array(Y), Xoh, Yoh

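# Example usage (an illustrative sketch, not part of the original module):
# Tx=30 is the input length used elsewhere in this file, and Ty=10 matches the
# 10-character ISO format 'YYYY-MM-DD' produced by dt.isoformat().
#
#     dataset, human_vocab, machine_vocab, inv_machine_vocab = load_dataset(10000)
#     X, Y, Xoh, Yoh = preprocess_data(dataset, human_vocab, machine_vocab, Tx=30, Ty=10)
#     # Xoh has shape (m, 30, len(human_vocab)); Yoh has shape (m, 10, len(machine_vocab)),
#     # where m = len(dataset)
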
def string_to_int(string, length, vocab):
    """
    Converts all strings in the vocabulary into a list of integers representing the positions of the
    input string's characters in the "vocab"

    Arguments:
    string -- input string, e.g. 'Wed 10 Jul 2007'
    length -- the number of time steps you'd like; determines if the output will be padded or cut
    vocab -- vocabulary, dictionary used to index every character of your "string"

    Returns:
    rep -- list of integers (or '<unk>') (size = length) representing the positions of the string's characters in the vocabulary
    """
    # make lower to standardize
    string = string.lower()
    string = string.replace(',', '')

    if len(string) > length:
        string = string[:length]

    rep = list(map(lambda x: vocab.get(x, '<unk>'), string))

    if len(string) < length:
        rep += [vocab['<pad>']] * (length - len(string))

    return rep


def int_to_string(ints, inv_vocab):
    """
    Output a machine readable list of characters based on a list of indexes in the machine's vocabulary

    Arguments:
    ints -- list of integers representing indexes in the machine's vocabulary
    inv_vocab -- dictionary mapping machine readable indexes to machine readable characters

    Returns:
    l -- list of characters corresponding to the indexes of ints thanks to the inv_vocab mapping
    """
    l = [inv_vocab[i] for i in ints]
    return l


EXAMPLES = ['3 May 1979', '5 Apr 09', '20th February 2016', 'Wed 10 Jul 2007']
TIME_STEPS = 30  # input sequence length (Tx) used to encode the examples below


def run_example(model, input_vocabulary, inv_output_vocabulary, text):
    encoded = string_to_int(text, TIME_STEPS, input_vocabulary)
    prediction = model.predict(np.array([encoded]))
    prediction = np.argmax(prediction[0], axis=-1)
    return int_to_string(prediction, inv_output_vocabulary)


def run_examples(model, input_vocabulary, inv_output_vocabulary, examples=EXAMPLES):
    predicted = []
    for example in examples:
        predicted.append(''.join(run_example(model, input_vocabulary, inv_output_vocabulary, example)))
        print('input:', example)
        print('output:', predicted[-1])
    return predicted


def softmax(x, axis=1):
    """Softmax activation function.
    # Arguments
        x: Tensor.
        axis: Integer, axis along which the softmax normalization is applied.
    # Returns
        Tensor, output of softmax transformation.
    # Raises
        ValueError: In case `dim(x) == 1`.
    """
    ndim = K.ndim(x)
    if ndim == 2:
        return K.softmax(x)
    elif ndim > 2:
        # numerically stable softmax along the requested axis
        e = K.exp(x - K.max(x, axis=axis, keepdims=True))
        s = K.sum(e, axis=axis, keepdims=True)
        return e / s
    else:
        raise ValueError('Cannot apply softmax to a tensor that is 1D')

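# Quick check of the axis argument (illustrative only): for attention energies
# of shape (m, Tx, 1), softmax(x, axis=1) normalizes across the Tx input time
# steps, so each example's attention weights sum to 1.
#
#     e = K.constant(np.random.randn(2, 30, 1))
#     alphas = softmax(e, axis=1)
#     # K.eval(K.sum(alphas, axis=1)) -> ones of shape (2, 1)
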
def plot_attention_map(model, input_vocabulary, inv_output_vocabulary, text, n_s=128, num=6, Tx=30, Ty=10):
    """
    Plot the attention map.
    """
    attention_map = np.zeros((10, 30))
    Ty, Tx = attention_map.shape

    s0 = np.zeros((1, n_s))
    c0 = np.zeros((1, n_s))
    layer = model.layers[num]

    encoded = np.array(string_to_int(text, Tx, input_vocabulary)).reshape((1, 30))
    encoded = np.array(list(map(lambda x: to_categorical(x, num_classes=len(input_vocabulary)), encoded)))

    f = K.function(model.inputs, [layer.get_output_at(t) for t in range(Ty)])
    r = f([encoded, s0, c0])

    for t in range(Ty):
        for t_prime in range(Tx):
            attention_map[t][t_prime] = r[t][0, t_prime, 0]

    # Normalize attention map
    # row_max = attention_map.max(axis=1)
    # attention_map = attention_map / row_max[:, None]

    prediction = model.predict([encoded, s0, c0])

    predicted_text = []
    for i in range(len(prediction)):
        predicted_text.append(int(np.argmax(prediction[i], axis=1)))

    predicted_text = list(predicted_text)
    predicted_text = int_to_string(predicted_text, inv_output_vocabulary)
    text_ = list(text)

    # get the lengths of the string
    input_length = len(text)
    output_length = Ty

    # Plot the attention_map
    plt.clf()
    fig = plt.figure(figsize=(8, 8.5))
    ax = fig.add_subplot(1, 1, 1)

    # add image
    i = ax.imshow(attention_map, interpolation='nearest', cmap='Blues')

    # add colorbar
    cbaxes = fig.add_axes([0.2, 0, 0.6, 0.03])
    cbar = fig.colorbar(i, cax=cbaxes, orientation='horizontal')
    cbar.ax.set_xlabel('Alpha value (Probability output of the "softmax")', labelpad=2)

    # add labels
    ax.set_yticks(range(output_length))
    ax.set_yticklabels(predicted_text[:output_length])

    ax.set_xticks(range(input_length))
    ax.set_xticklabels(text_[:input_length], rotation=45)

    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Output Sequence')

    # add grid and legend
    ax.grid()

    return attention_map

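# Example end-to-end usage with a trained attention model (a sketch; assumes
# `model` takes inputs [X, s0, c0] and that model.layers[num] is the layer
# producing the attention weights; adjust num and n_s to your architecture):
#
#     dataset, human_vocab, machine_vocab, inv_machine_vocab = load_dataset(10000)
#     run_examples(model, human_vocab, inv_machine_vocab)
#     plot_attention_map(model, human_vocab, inv_machine_vocab,
#                        'Tuesday 09 Oct 1993', num=7, n_s=64)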