Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/master/Natural Language Processing with Attention Models/Week 1 - Neural Machine Translation/w1_unittest.py
Views: 13373
# Unit tests for the "Neural Machine Translation" assignment (UNQ_C1 - UNQ_C10).
#
# Every ``test_*`` function receives the learner's implementation (and, where
# needed, a trained model), runs a handful of checks and prints a coloured
# pass/fail summary.  Failures are reported, never raised, so a notebook cell
# that calls a test does not abort on the first wrong answer.
import numpy as np
import trax
from trax import layers as tl
from trax.fastmath import numpy as fastnp
from trax.supervised import training

VOCAB_FILE = 'ende_32k.subword'
VOCAB_DIR = 'data/'


class _Tally:
    """Pass/fail counters for one unit test, plus the shared summary printer."""

    def __init__(self):
        self.success = 0
        self.fails = 0

    def record(self, passed, *message):
        """Count one check; on failure print ``message`` (same arguments as ``print``).

        Returns:
            bool: ``passed``, so callers can gate dependent checks on it.
        """
        passed = bool(passed)
        if passed:
            self.success += 1
        else:
            self.fails += 1
            print(*message)
        return passed

    def report(self):
        """Print the coloured summary (green for passes, red for failures)."""
        if self.fails == 0:
            print("\033[92m All tests passed")
        else:
            print('\033[92m', self.success, " Tests passed")
            print('\033[91m', self.fails, " Tests failed")


def _passes(check, *args):
    """Return whether ``check(*args)`` is truthy, treating any exception as a failure.

    Exceptions raised by the learner's code (wrong shapes, missing attributes,
    ...) must count as a failed check rather than abort the whole test.
    Catching ``Exception`` instead of using a bare ``except`` keeps
    KeyboardInterrupt / SystemExit working.
    """
    try:
        return bool(check(*args))
    except Exception:
        return False


def _without_spaces(text):
    """Return ``text`` with every space character removed.

    Layer reprs are compared this way so that the indentation trax uses for
    nested combinators cannot cause a false failure.
    """
    return text.replace(" ", "")


def jaccard_similarity(candidate, reference):
    """Returns the Jaccard similarity between two token lists

    Args:
        candidate (list of int): tokenized version of the candidate translation
        reference (list of int): tokenized version of the reference translation

    Returns:
        float: number of unique tokens shared by both lists divided by the
            number of unique tokens in either list; 0.0 when both lists are
            empty (the ratio is undefined there and would otherwise raise
            ZeroDivisionError).
    """
    # convert the lists to sets to get the unique tokens
    can_unigram_set, ref_unigram_set = set(candidate), set(reference)

    # tokens common to both lists, and tokens found in either list
    joint_elems = can_unigram_set & ref_unigram_set
    all_elems = can_unigram_set | ref_unigram_set

    if not all_elems:
        return 0.0

    return len(joint_elems) / len(all_elems)


def weighted_avg_overlap(similarity_fn, samples, log_probs):
    """Returns the weighted mean similarity of each candidate to the other samples

    Each sample is taken in turn as the candidate. Its score is the mean of
    ``similarity_fn(candidate, other)`` over every *other* sample, weighted by
    that other sample's linear-scale probability ``exp(log_prob)``.

    Args:
        similarity_fn (callable): ``similarity_fn(candidate, reference) -> float``,
            e.g. ``jaccard_similarity``.
        samples (list of lists): tokenized version of the translated sentences
        log_probs (list of float): log probability of the translated sentences

    Returns:
        dict: scores of each sample
            key: index of the sample
            value: score of the sample; 0.0 when the weights of the other
                samples sum to zero (a single sample, or probabilities that
                underflow to 0.0 in ``np.exp``).
    """
    # the linear-scale weights do not depend on the candidate: compute them once
    probs = [float(np.exp(logp)) for logp in log_probs]

    scores = {}
    for index_candidate, candidate in enumerate(samples):
        overlap, weight_sum = 0.0, 0.0

        for index_sample, (sample, sample_p) in enumerate(zip(samples, probs)):
            # a candidate is never compared with itself
            if index_candidate == index_sample:
                continue
            weight_sum += sample_p
            overlap += sample_p * similarity_fn(candidate, sample)

        # avoid ZeroDivisionError when there is nothing to weight by
        scores[index_candidate] = overlap / weight_sum if weight_sum else 0.0

    return scores


def _check_serial_model(tally, model, expected, n_sublayers, label):
    """Structure, type and depth checks shared by the UNQ_C1 and UNQ_C2 tests.

    Args:
        tally (_Tally): counters to update.
        model: layer returned by the learner's function.
        expected (str): expected ``str(model)``; compared ignoring spaces.
        n_sublayers (int): expected ``len(model.sublayers)``.
        label (str): model name used in the type-check failure message.
    """
    proposed = str(model)

    # Test all layers are in the expected sequence
    tally.record(_without_spaces(proposed) == _without_spaces(expected),
                 "Wrong model. \nProposed:\n%s" % proposed,
                 "\nExpected:\n%s" % expected)

    # Test the output type; the number of layers is only checked for a Serial
    serial_type = trax.layers.combinators.Serial
    if tally.record(isinstance(model, serial_type),
                    f"The {label} is not an object of ", serial_type):
        n_found = len(model.sublayers)
        tally.record(n_found == n_sublayers,
                     'The number of sublayers does not match %s <>' % n_found,
                     " %s" % n_sublayers)


# UNIT TEST for UNQ_C1
def test_input_encoder_fn(input_encoder_fn):
    """Check that the encoder is Serial[Embedding, LSTM x n_encoder_layers]."""
    input_vocab_size = 10
    d_model = 2
    n_encoder_layers = 6

    encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers)

    lstms = "\n".join([f' LSTM_{d_model}'] * n_encoder_layers)
    expected = f"Serial[\n Embedding_{input_vocab_size}_{d_model}\n{lstms}\n]"

    tally = _Tally()
    _check_serial_model(tally, encoder, expected, n_encoder_layers + 1, "encoder")
    tally.report()


# UNIT TEST for UNQ_C2
def test_pre_attention_decoder_fn(pre_attention_decoder_fn):
    """Check that the pre-attention decoder is Serial[ShiftRight, Embedding, LSTM]."""
    mode = 'train'
    target_vocab_size = 10
    d_model = 2

    decoder = pre_attention_decoder_fn(mode, target_vocab_size, d_model)

    expected = f"Serial[\n ShiftRight(1)\n Embedding_{target_vocab_size}_{d_model}\n LSTM_{d_model}\n]"

    tally = _Tally()
    _check_serial_model(tally, decoder, expected, 3, "decoder")
    tally.report()


# UNIT TEST for UNQ_C3
def test_prepare_attention_input(prepare_attention_input):
    """Check queries/keys/values routing, the padding mask and the output types."""
    # This unit test considers batch size = 2, number_of_tokens = 3 and embedding_size = 4
    enc_act = fastnp.array([[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]],
                            [[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 0, 0]]])
    dec_act = fastnp.array([[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0]],
                            [[2, 0, 2, 0], [0, 2, 0, 2], [0, 0, 0, 0]]])
    inputs = fastnp.array([[1, 2, 3], [1, 4, 0]])

    # shape (2, 1, 3, 3); the last axis follows the input tokens, so the
    # position holding token id 0 in the second sentence is 0 in every row
    exp_mask = fastnp.array([[[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]],
                             [[[1., 1., 0.], [1., 1., 0.], [1., 1., 0.]]]])

    exp_type = type(enc_act)

    queries, keys, values, mask = prepare_attention_input(enc_act, dec_act, inputs)

    tally = _Tally()
    tally.record(_passes(fastnp.allclose, queries, dec_act),
                 "Queries does not match the decoder activations")
    tally.record(_passes(fastnp.allclose, keys, enc_act),
                 "Keys does not match the encoder activations")
    tally.record(_passes(fastnp.allclose, values, enc_act),
                 "Values does not match the encoder activations")
    tally.record(_passes(fastnp.allclose, mask, exp_mask),
                 "Mask does not match expected tensor. \nExpected:\n%s" % exp_mask,
                 "\nOutput:\n%s" % (mask,))

    # Test the output type (reported against the array type actually in use;
    # the old message referenced ``jax``, which this module never imports)
    tally.record(all(isinstance(out, exp_type) for out in (queries, keys, values, mask)),
                 "One of the output objects is not of type ", exp_type)
    tally.report()


# UNIT TEST for UNQ_C4
def test_NMTAttn(NMTAttn):
    """Check the full attention model's structure, depth and selection layers."""
    test_cases = [
        {
            "name": "simple_test_check",
            "expected": (
                "Serial_in2_out2[\n"
                " Select[0,1,0,1]_in2_out4\n"
                " Parallel_in2_out2[\n"
                " Serial[\n"
                " Embedding_33300_1024\n"
                " LSTM_1024\n"
                " LSTM_1024\n"
                " ]\n"
                " Serial[\n"
                " ShiftRight(1)\n"
                " Embedding_33300_1024\n"
                " LSTM_1024\n"
                " ]\n"
                " ]\n"
                " PrepareAttentionInput_in3_out4\n"
                " Serial_in4_out2[\n"
                " Branch_in4_out3[\n"
                " None\n"
                " Serial_in4_out2[\n"
                " Parallel_in3_out3[\n"
                " Dense_1024\n"
                " Dense_1024\n"
                " Dense_1024\n"
                " ]\n"
                " PureAttention_in4_out2\n"
                " Dense_1024\n"
                " ]\n"
                " ]\n"
                " Add_in2\n"
                " ]\n"
                " Select[0,2]_in3_out2\n"
                " LSTM_1024\n"
                " LSTM_1024\n"
                " Dense_33300\n"
                " LogSoftmax\n"
                "]"
            ),
            "error": "The NMTAttn is not defined properly."
        },
        {
            "name": "layer_len_check",
            "expected": 9,
            "error": "We found {} layers in your model. It should be 9.\nCheck the LSTM stack before the dense layer"
        },
        {
            "name": "selection_layer_check",
            "expected": ["Select[0,1,0,1]_in2_out4", "Select[0,2]_in3_out2"],
            "error": "Look at your selection layers."
        }
    ]

    tally = _Tally()
    for test_case in test_cases:
        name = test_case["name"]
        expected = test_case["expected"]
        error = test_case["error"]
        try:
            model = NMTAttn()
            if name == "simple_test_check":
                # compared ignoring spaces, like the UNQ_C1/C2 structure tests
                tally.record(_without_spaces(str(model)) == _without_spaces(expected), error)
            elif name == "layer_len_check":
                n_layers = len(model.sublayers)
                tally.record(n_layers == expected, error.format(n_layers))
            elif name == "selection_layer_check":
                selections = [str(model.sublayers[0]), str(model.sublayers[4])]
                tally.record(selections == expected, error)
        except Exception:
            # the layer-count message has a "{}" slot that must not print raw
            tally.record(False, error.format("?"))

    tally.report()


# UNIT TEST for UNQ_C5
def test_train_task(train_task):
    """Check the data, loss, optimizer, schedule and checkpointing of a TrainTask."""
    checks = [
        # the labeled data must be the generator built by add_loss_weights;
        # ``in`` replaces ``str.find``, which returns -1 (truthy) on a miss
        (lambda: "generator" in str(train_task._labeled_data)
         and "add_loss_weights" in str(train_task._labeled_data),
         "Wrong labeled data parameter"),
        (lambda: str(train_task._loss_layer) == "CrossEntropyLoss_in3",
         "Wrong loss functions. CrossEntropyLoss_in3 was expected"),
        (lambda: isinstance(train_task.optimizer, trax.optimizers.adam.Adam),
         "Wrong optimizer"),
        (lambda: isinstance(train_task._lr_schedule,
                            trax.supervised.lr_schedules._BodyAndTail),
         "Wrong learning rate schedule type"),
        (lambda: train_task._n_steps_per_checkpoint == 10,
         "Wrong checkpoint step frequency"),
    ]

    tally = _Tally()
    for check, message in checks:
        tally.record(_passes(check), message)
    tally.report()


# UNIT TEST for UNQ_C6
def test_next_symbol(next_symbol, model):
    """Check the (token, log-probability) pair returned by next_symbol."""
    tokens_en = np.array([[17332, 140, 172, 207, 1]])

    def returns_int_float_pair():
        output = next_symbol(model, tokens_en, [], 0.0)
        # exact built-in types are required: np.int64 is rejected, and so is
        # np.float64 even though it subclasses float
        return (isinstance(output, tuple)
                and len(output) == 2
                and type(output[0]) is int
                and type(output[1]) is float)

    def predicts_expected_symbol():
        output = next_symbol(model, tokens_en, [18477], 0.0)
        return np.allclose([output[0], output[1]], [140, -0.000217437744])

    tally = _Tally()
    tally.record(_passes(returns_int_float_pair),
                 "Output must be a tuple of size 2 containing a integer and a float number")
    tally.record(_passes(predicts_expected_symbol),
                 "Expected output: ", [140, -0.000217437744])
    tally.report()


# UNIT TEST for UNQ_C7
def test_sampling_decode(sampling_decode, model):
    """Check greedy (temperature=0) decoding of two short English sentences."""
    # only the detokenized sentence (element 2 of the output) is compared
    test_cases = [
        ("I eat soup.", 'Ich iss Suppe.'),
        ("I like your shoes.", 'Ich mag Ihre Schuhe.'),
    ]

    def decodes_to(sentence, expected_text):
        output = sampling_decode(sentence, model, temperature=0,
                                 vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)
        return output[2] == expected_text

    tally = _Tally()
    for number, (sentence, expected_text) in enumerate(test_cases, start=1):
        tally.record(_passes(decodes_to, sentence, expected_text), f"Test {number} fails")
    tally.report()


# UNIT TEST for UNQ_C8
def test_rouge1_similarity(rouge1_similarity):
    """Check unigram (ROUGE-1 F1) similarity on small hand-computed token lists."""
    test_cases = [
        # (candidate, reference, expected similarity)
        ([1, 2, 3], [1, 2, 3, 4], 0.8571428571428571),
        ([2, 1], [3, 1], 0.5),
        ([2], [3], 0),
        ([0] * 100 + [2] * 100, [0] * 100 + [1] * 100, 0.5),
    ]

    def matches(candidate, reference, expected):
        return abs(expected - rouge1_similarity(candidate, reference)) < 1e-6

    tally = _Tally()
    for candidate, reference, expected in test_cases:
        tally.record(_passes(matches, candidate, reference, expected),
                     f"Expected similarity: {expected}")
    tally.report()


# UNIT TEST for UNQ_C9
def test_average_overlap(average_overlap):
    """Check the per-sample average Jaccard overlap scores."""
    test_cases = [
        ([[1, 2], [3, 4], [1, 2], [3, 5]],
         {0: 0.3333333333333333,
          1: 0.1111111111111111,
          2: 0.3333333333333333,
          3: 0.1111111111111111}),
        ([[1, 2], [3, 4], [1, 2, 5], [3, 5], [3, 4, 1]],
         {0: 0.22916666666666666,
          1: 0.25,
          2: 0.2791666666666667,
          3: 0.20833333333333331,
          4: 0.3416666666666667}),
    ]

    def matches(samples, expected):
        output = average_overlap(jaccard_similarity, samples)
        # every sample must be scored; otherwise an empty or partial dict
        # would pass the per-key comparison vacuously
        if set(output) != set(expected):
            return False
        return all(abs(output[index] - expected[index]) < 1e-5 for index in expected)

    tally = _Tally()
    for samples, expected in test_cases:
        tally.record(_passes(matches, samples, expected), "Expected output does not match")
    tally.report()


# UNIT TEST for UNQ_C10
def test_mbr_decode(mbr_decode, model):
    """Check MBR decoding results and that the best-scoring sample is returned."""
    test_cases = [
        ("I am hungry", "Ich bin hungrig."),
        ('Congratulations!', 'Herzlichen Glückwunsch!'),
        ('You have completed the assignment!', 'Sie haben die Abtretung abgeschlossen!'),
    ]

    def decode(sentence, temperature):
        return mbr_decode(sentence, 4, weighted_avg_overlap, jaccard_similarity,
                          model, temperature, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)

    def translates_to(sentence, expected):
        return decode(sentence, 0.0)[0] == expected

    def selects_best_sample(sentence):
        result = decode(sentence, 0.5)
        # result[2] maps sample index -> score; result[1] is the chosen index
        return max(result[2], key=result[2].get) == result[1]

    tally = _Tally()
    for sentence, expected in test_cases:
        tally.record(_passes(translates_to, sentence, expected), "Expected output does not match")

    # Test that function return the most likely translation
    tally.record(_passes(selects_best_sample, test_cases[0][0]),
                 'Use max function to select max_index')
    tally.report()