CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
y33-j3T

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: y33-j3T/Coursera-Deep-Learning
Path: blob/master/Natural Language Processing with Attention Models/Week 1 - Neural Machine Translation/w1_unittest.py
Views: 13373
1
import numpy as np
2
3
import trax
4
from trax import layers as tl
5
from trax.fastmath import numpy as fastnp
6
from trax.supervised import training
7
8
VOCAB_FILE = 'ende_32k.subword'
9
VOCAB_DIR = 'data/'
10
11
12
def jaccard_similarity(candidate, reference):
    """Returns the Jaccard similarity between two token lists

    Duplicated tokens are ignored: both lists are reduced to their sets of
    unique tokens before comparing.

    Args:
        candidate (list of int): tokenized version of the candidate translation
        reference (list of int): tokenized version of the reference translation

    Returns:
        float: overlap between the two token lists, i.e. the number of unique
            tokens common to both divided by the number of unique tokens found
            in either. Returns 0.0 when both lists are empty.
    """

    # convert the lists to sets to get the unique tokens
    can_unigram_set, ref_unigram_set = set(candidate), set(reference)

    # get the set of all tokens found in either candidate or reference
    all_elems = can_unigram_set | ref_unigram_set

    # two empty lists have an empty union; report "no overlap" instead of
    # dividing by zero
    if not all_elems:
        return 0.0

    # get the set of tokens common to both candidate and reference
    joint_elems = can_unigram_set & ref_unigram_set

    # divide the number of joint elements by the number of all elements
    return len(joint_elems) / len(all_elems)
36
37
38
def weighted_avg_overlap(similarity_fn, samples, log_probs):
    """Returns the weighted mean of each candidate sentence in the samples

    Each sample is scored by its similarity to every *other* sample, weighted
    by the (linear-scale) probability of the sample it is compared against.

    Args:
        similarity_fn (function): called as similarity_fn(candidate, sample);
            its result is multiplied by a float weight
        samples (list of lists): tokenized version of the translated sentences
        log_probs (list of float): log probability of the translated sentences

    Returns:
        dict: scores of each sample
            key: index of the sample
            value: score of the sample (0.0 when the other samples carry no
                probability mass, e.g. when there is only one sample)
    """

    # initialize dictionary
    scores = {}

    # run a for loop for each sample
    for index_candidate, candidate in enumerate(samples):

        # initialize overlap and weighted sum
        overlap, weight_sum = 0.0, 0.0

        # compare the candidate against every other sample
        for index_sample, (sample, logp) in enumerate(zip(samples, log_probs)):

            # skip if the candidate index is the same as the sample index
            if index_candidate == index_sample:
                continue

            # convert log probability to linear scale
            sample_p = float(np.exp(logp))

            # update the weighted sum
            weight_sum += sample_p

            # accumulate the probability-weighted overlap with this sample
            overlap += sample_p * similarity_fn(candidate, sample)

        # get the score for the candidate; with no other sample (or only
        # zero-probability ones) there is nothing to average, so avoid 0/0
        score = overlap / weight_sum if weight_sum else 0.0

        # save the score in the dictionary. use index as the key.
        scores[index_candidate] = score

    return scores
86
87
88
# UNIT TEST for UNQ_C1
89
def test_input_encoder_fn(input_encoder_fn):
    """Unit test for UNQ_C1: checks the encoder built by `input_encoder_fn`.

    Builds an encoder with a small fixed configuration and verifies that:
      1. its string representation is an Embedding layer followed by the
         requested number of LSTM layers (spaces are ignored),
      2. it is a `trax.layers.combinators.Serial`,
      3. it has `n_encoder_layers + 1` sublayers (only checked when 2. holds).

    A pass/fail summary is printed; nothing is returned.

    Args:
        input_encoder_fn (function): called as
            input_encoder_fn(input_vocab_size, d_model, n_encoder_layers)
    """
    success = 0
    fails = 0

    input_vocab_size = 10
    d_model = 2
    n_encoder_layers = 6

    encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers)

    lstms = "\n".join([f' LSTM_{d_model}'] * n_encoder_layers)

    expected = f"Serial[\n Embedding_{input_vocab_size}_{d_model}\n{lstms}\n]"

    proposed = str(encoder)

    # Test all layers are in the expected sequence (explicit comparison rather
    # than `assert`, so the check still runs under `python -O`)
    if proposed.replace(" ", "") == expected.replace(" ", ""):
        success += 1
    else:
        fails += 1
        print("Wrong model. \nProposed:\n%s" %proposed, "\nExpected:\n%s" %expected)

    # Test the output type
    if isinstance(encoder, trax.layers.combinators.Serial):
        success += 1

        # Test the number of layers: one Embedding plus the LSTM stack
        expected_sublayers = n_encoder_layers + 1
        found_sublayers = len(encoder.sublayers)
        if found_sublayers == expected_sublayers:
            success += 1
        else:
            fails += 1
            print('The number of sublayers does not match %s <>' %found_sublayers, " %s" %expected_sublayers)
    else:
        fails += 1
        print("The encoder is not an object of ", trax.layers.combinators.Serial)

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
136
137
138
# UNIT TEST for UNQ_C2
139
def test_pre_attention_decoder_fn(pre_attention_decoder_fn):
    """Unit test for UNQ_C2: checks the decoder built by `pre_attention_decoder_fn`.

    Builds a decoder in 'train' mode with a small fixed configuration and
    verifies that:
      1. its string representation is ShiftRight(1), Embedding, LSTM inside a
         Serial (spaces are ignored),
      2. it is a `trax.layers.combinators.Serial`,
      3. it has exactly 3 sublayers (only checked when 2. holds).

    A pass/fail summary is printed; nothing is returned.

    Args:
        pre_attention_decoder_fn (function): called as
            pre_attention_decoder_fn(mode, target_vocab_size, d_model)
    """
    success = 0
    fails = 0

    mode = 'train'
    target_vocab_size = 10
    d_model = 2
    expected_sublayers = 3

    decoder = pre_attention_decoder_fn(mode, target_vocab_size, d_model)

    expected = f"Serial[\n ShiftRight(1)\n Embedding_{target_vocab_size}_{d_model}\n LSTM_{d_model}\n]"

    proposed = str(decoder)

    # Test all layers are in the expected sequence (explicit comparison rather
    # than `assert`, so the check still runs under `python -O`)
    if proposed.replace(" ", "") == expected.replace(" ", ""):
        success += 1
    else:
        fails += 1
        print("Wrong model. \nProposed:\n%s" %proposed, "\nExpected:\n%s" %expected)

    # Test the output type
    if isinstance(decoder, trax.layers.combinators.Serial):
        success += 1

        # Test the number of layers
        found_sublayers = len(decoder.sublayers)
        if found_sublayers == expected_sublayers:
            success += 1
        else:
            fails += 1
            print('The number of sublayers does not match %s <>' %found_sublayers, " %s" %expected_sublayers)
    else:
        fails += 1
        # this test is about the decoder, so say so in the message
        print("The decoder is not an object of ", trax.layers.combinators.Serial)

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
184
185
186
# UNIT TEST for UNQ_C3
187
def test_prepare_attention_input(prepare_attention_input):
    """Unit test for UNQ_C3: checks the outputs of `prepare_attention_input`.

    Calls the function with fixed encoder activations, decoder activations and
    input tokens, then verifies that:
      1. queries equal the decoder activations,
      2. keys equal the encoder activations,
      3. values equal the encoder activations,
      4. the mask equals the expected tensor (0 in the column of the input
         token that is 0, 1 elsewhere),
      5. all four outputs have the same type as the arrays passed in.

    A pass/fail summary is printed; nothing is returned.

    Args:
        prepare_attention_input (function): called as
            prepare_attention_input(encoder_activations, decoder_activations, inputs)
            and expected to return (queries, keys, values, mask)
    """
    target = prepare_attention_input
    success = 0
    fails = 0

    # This unit test considers batch size = 2, number_of_tokens = 3 and embedding_size = 4

    enc_act = fastnp.array([[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]],
                            [[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 0, 0]]])
    dec_act = fastnp.array([[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0]],
                            [[2, 0, 2, 0], [0, 2, 0, 2], [0, 0, 0, 0]]])
    inputs = fastnp.array([[1, 2, 3], [1, 4, 0]])

    exp_mask = fastnp.array([[[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]],
                             [[[1., 1., 0.], [1., 1., 0.], [1., 1., 0.]]]])

    exp_type = type(enc_act)

    queries, keys, values, mask = target(enc_act, dec_act, inputs)

    def tensors_match(actual, reference):
        # the comparison itself may raise (e.g. on shapes that cannot be
        # broadcast together); that counts as a mismatch, not a crash
        try:
            return bool(fastnp.allclose(actual, reference))
        except Exception:
            return False

    if tensors_match(queries, dec_act):
        success += 1
    else:
        fails += 1
        print("Queries does not match the decoder activations")

    if tensors_match(keys, enc_act):
        success += 1
    else:
        fails += 1
        print("Keys does not match the encoder activations")

    if tensors_match(values, enc_act):
        success += 1
    else:
        fails += 1
        print("Values does not match the encoder activations")

    if tensors_match(mask, exp_mask):
        success += 1
    else:
        fails += 1
        print("Mask does not match expected tensor. \nExpected:\n%s" %exp_mask, "\nOutput:\n%s" %mask)

    # Test the output type
    if all(isinstance(tensor, exp_type) for tensor in (queries, keys, values, mask)):
        success += 1
    else:
        fails += 1
        # report the type actually compared against; `jax` is not imported in
        # this module, so naming a jax class here would raise NameError
        print("One of the output object are not of type ", exp_type)

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
248
249
250
# UNIT TEST for UNQ_C4
251
def test_NMTAttn(NMTAttn):
    """Unit test for UNQ_C4: checks the structure of the model built by `NMTAttn`.

    Builds the model with its default arguments once per check and verifies:
      - simple_test_check: the string representation of the model matches the
        expected layer listing (spaces are ignored, as in the encoder and
        decoder tests of this module),
      - layer_len_check: the model has 9 sublayers,
      - selection_layer_check: sublayers 0 and 4 are the expected Select layers.

    Any exception raised while building or inspecting the model counts as a
    failure of that check. A pass/fail summary is printed; nothing is returned.

    Args:
        NMTAttn (function): called with no arguments; must return an object
            exposing `sublayers` and a string representation
    """
    test_cases = [
        {
            "name":"simple_test_check",
            "expected":"Serial_in2_out2[\n Select[0,1,0,1]_in2_out4\n Parallel_in2_out2[\n Serial[\n Embedding_33300_1024\n LSTM_1024\n LSTM_1024\n ]\n Serial[\n ShiftRight(1)\n Embedding_33300_1024\n LSTM_1024\n ]\n ]\n PrepareAttentionInput_in3_out4\n Serial_in4_out2[\n Branch_in4_out3[\n None\n Serial_in4_out2[\n Parallel_in3_out3[\n Dense_1024\n Dense_1024\n Dense_1024\n ]\n PureAttention_in4_out2\n Dense_1024\n ]\n ]\n Add_in2\n ]\n Select[0,2]_in3_out2\n LSTM_1024\n LSTM_1024\n Dense_33300\n LogSoftmax\n]",
            "error":"The NMTAttn is not defined properly."
        },
        {
            "name":"layer_len_check",
            "expected":9,
            "error":"We found {} layers in your model. It should be 9.\nCheck the LSTM stack before the dense layer"
        },
        {
            "name":"selection_layer_check",
            "expected":["Select[0,1,0,1]_in2_out4", "Select[0,2]_in3_out2"],
            "error":"Look at your selection layers."
        }
    ]

    success = 0
    fails = 0

    for test_case in test_cases:
        name = test_case['name']
        expected = test_case["expected"]
        try:
            # build the model once per check instead of once per use
            model = NMTAttn()

            if name == "simple_test_check":
                # compare ignoring spaces so indentation width in the layer
                # listing cannot cause a spurious failure
                passed = expected.replace(" ", "") == str(model).replace(" ", "")
                message = test_case["error"]
            elif name == "layer_len_check":
                n_layers = len(model.sublayers)
                passed = expected == n_layers
                message = test_case["error"].format(n_layers)
            elif name == "selection_layer_check":
                output = [str(model.sublayers[0]), str(model.sublayers[4])]
                passed = output == expected
                message = test_case["error"]
            else:
                continue
        except Exception:
            # e.g. the model cannot be built or has too few sublayers
            passed = False
            message = test_case['error']

        if passed:
            success += 1
        else:
            print(message)
            fails += 1

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
306
307
308
# UNIT TEST for UNQ_C5
309
def test_train_task(train_task):
    """Unit test for UNQ_C5: checks the configuration of a training task.

    Verifies, independently of each other, that:
      1. str(train_task._labeled_data) mentions both "generator" and
         "add_loss_weights",
      2. str(train_task._loss_layer) is "CrossEntropyLoss_in3",
      3. train_task.optimizer is a trax Adam optimizer,
      4. train_task._lr_schedule is a trax `_BodyAndTail` schedule,
      5. train_task._n_steps_per_checkpoint is 10.

    A check that raises (e.g. because the attribute is missing) counts as
    failed. A pass/fail summary is printed; nothing is returned.

    Args:
        train_task: object exposing the attributes listed above
    """
    target = train_task
    success = 0
    fails = 0

    def labeled_data_ok():
        strlabel = str(target._labeled_data)
        # membership test: `str.find` returns -1 (truthy) when the text is
        # absent and 0 (falsy) when it is at the start, so it cannot be used
        # as a boolean
        return "generator" in strlabel and 'add_loss_weights' in strlabel

    checks = [
        # Test the labeled data parameter
        (labeled_data_ok, "Wrong labeled data parameter"),
        # Test the cross entropy loss data parameter
        (lambda: str(target._loss_layer) == "CrossEntropyLoss_in3",
         "Wrong loss functions. CrossEntropyLoss_in3 was expected"),
        # Test the optimizer parameter
        (lambda: isinstance(target.optimizer, trax.optimizers.adam.Adam),
         "Wrong optimizer"),
        # Test the schedule parameter
        (lambda: isinstance(target._lr_schedule, trax.supervised.lr_schedules._BodyAndTail),
         "Wrong learning rate schedule type"),
        # Test the _n_steps_per_checkpoint parameter
        (lambda: target._n_steps_per_checkpoint == 10,
         "Wrong checkpoint step frequency"),
    ]

    for check, message in checks:
        try:
            passed = bool(check())
        except Exception:
            passed = False

        if passed:
            success += 1
        else:
            fails += 1
            print(message)

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
361
362
363
364
# UNIT TEST for UNQ_C6
365
def test_next_symbol(next_symbol, model):
    """Unit test for UNQ_C6: checks the output of `next_symbol`.

    Verifies that:
      1. with no previous output tokens the function returns a tuple of
         exactly (int, float),
      2. with previous output tokens [18477] the returned pair is close to
         (140, -0.000217437744).

    An exception raised by `next_symbol` counts as a failed check; a keyboard
    interrupt is not swallowed. A pass/fail summary is printed; nothing is
    returned.

    Args:
        next_symbol (function): called as
            next_symbol(model, input_tokens, cur_output_tokens, temperature)
        model: passed through unchanged to `next_symbol`
    """
    target = next_symbol
    the_model = model
    success = 0
    fails = 0

    tokens_en = np.array([[17332, 140, 172, 207, 1]])
    expected_pair = [140, -0.000217437744]

    # Test the type and size of output
    try:
        next_de_tokens = target(the_model, tokens_en, [], 0.0)
        # exact built-in types are required on purpose (a numpy scalar is
        # rejected), hence `type(...) is` rather than isinstance
        passed = (
            isinstance(next_de_tokens, tuple)
            and len(next_de_tokens) == 2
            and type(next_de_tokens[0]) is int
            and type(next_de_tokens[1]) is float
        )
    except Exception:
        passed = False

    if passed:
        success += 1
    else:
        fails += 1
        print("Output must be a tuple of size 2 containing a integer and a float number")

    # Test an output
    try:
        next_de_tokens = target(the_model, tokens_en, [18477], 0.0)
        passed = bool(np.allclose([next_de_tokens[0], next_de_tokens[1]], expected_pair))
    except Exception:
        passed = False

    if passed:
        success += 1
    else:
        fails += 1
        print("Expected output: ", expected_pair)

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
399
400
401
# UNIT TEST for UNQ_C7
402
def test_sampling_decode(sampling_decode, model):
    """Unit test for UNQ_C7: checks the translations produced by `sampling_decode`.

    Translates two English sentences with temperature 0 and compares element 2
    of each result with the expected German sentence. Only that element is
    checked. An exception raised while decoding counts as a failed check; a
    keyboard interrupt is not swallowed.

    A pass/fail summary is printed; nothing is returned.

    Args:
        sampling_decode (function): called as
            sampling_decode(sentence, model, temperature=0,
                            vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)
        model: passed through unchanged to `sampling_decode`
    """
    target = sampling_decode
    success = 0
    fails = 0

    # (input sentence, expected translation, message printed on failure)
    cases = [
        ("I eat soup.", 'Ich iss Suppe.', "Test 1 fails"),
        ("I like your shoes.", 'Ich mag Ihre Schuhe.', "Test 2 fails"),
    ]

    for sentence, expected_translation, message in cases:
        try:
            output = target(sentence, model, temperature=0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)
            passed = bool(output[2] == expected_translation)
        except Exception:
            passed = False

        if passed:
            success += 1
        else:
            fails += 1
            print(message)

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
431
432
433
# UNIT TEST for UNQ_C8
434
def test_rouge1_similarity(rouge1_similarity):
    """Unit test for UNQ_C8: checks `rouge1_similarity` on fixed token lists.

    Each case calls the function with two token lists and requires the result
    to be within 1e-6 of the expected similarity. An exception raised by the
    function counts as a failed case.

    A pass/fail summary is printed; nothing is returned.

    Args:
        rouge1_similarity (function): called as rouge1_similarity(list, list)
            and expected to return a number
    """
    target = rouge1_similarity
    success = 0
    fails = 0
    tolerance = 1e-6

    test_cases = [
        {
            "name":"simple_test_check",
            "input": [[1, 2, 3], [1, 2, 3, 4]],
            "expected":0.8571428571428571,
            "error":"Expected similarity: 0.8571428571428571"
        },
        {
            "name":"simple_test_check",
            "input":[[2, 1], [3, 1]],
            "expected":0.5,
            "error":"Expected similarity: 0.5"
        },
        {
            "name":"simple_test_check",
            "input":[[2], [3]],
            "expected":0,
            "error":"Expected similarity: 0"
        },
        {
            "name":"simple_test_check",
            "input":[[0] * 100 + [2] * 100, [0] * 100 + [1] * 100],
            "expected":0.5,
            "error":"Expected similarity: 0.5"
        }
    ]

    for test_case in test_cases:
        try:
            passed = bool(abs(test_case["expected"] - target(*test_case['input'])) < tolerance)
        except Exception:
            # e.g. the function raised or returned a non-numeric value
            passed = False

        if passed:
            success += 1
        else:
            print(test_case['error'])
            fails += 1

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
483
484
485
# UNIT TEST for UNQ_C9
486
def test_average_overlap(average_overlap):
    """Unit test for UNQ_C9: checks `average_overlap` with Jaccard similarity.

    Each case calls the function with `jaccard_similarity` and a list of token
    lists, and requires the returned mapping to have exactly the expected keys
    with every score within 1e-5 of the expected value. An exception raised
    while computing or comparing counts as a failed case.

    A pass/fail summary is printed; nothing is returned.

    Args:
        average_overlap (function): called as
            average_overlap(similarity_fn, samples) and expected to return a
            dict mapping sample index to score
    """
    target = average_overlap
    success = 0
    fails = 0
    tolerance = 1e-5

    test_cases = [
        {
            "name":"dict_test_check",
            "input": [jaccard_similarity, [[1, 2], [3, 4], [1, 2], [3, 5]]],
            "expected":{0: 0.3333333333333333,
                        1: 0.1111111111111111,
                        2: 0.3333333333333333,
                        3: 0.1111111111111111},
            "error":"Expected output does not match"
        },
        {
            "name":"dict_test_check",
            "input":[jaccard_similarity, [[1, 2], [3, 4], [1, 2, 5], [3, 5], [3, 4, 1]]],
            "expected":{0: 0.22916666666666666,
                        1: 0.25,
                        2: 0.2791666666666667,
                        3: 0.20833333333333331,
                        4: 0.3416666666666667},
            "error":"Expected output does not match"
        }
    ]

    for test_case in test_cases:
        expected = test_case['expected']
        try:
            output = target(*test_case['input'])
            # require the same set of keys: iterating over the output alone
            # would let an empty or partial dict pass
            passed = set(output) == set(expected) and all(
                abs(output[key] - expected[key]) < tolerance for key in expected
            )
        except Exception:
            passed = False

        if passed:
            success += 1
        else:
            print(test_case['error'])
            fails += 1

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
529
530
531
# UNIT TEST for UNQ_C10
532
def test_mbr_decode(mbr_decode, model):
    """Unit test for UNQ_C10: checks the translations selected by `mbr_decode`.

    First, with temperature 0.0, three sentences are decoded and element 0 of
    each result is compared with the expected German translation. Then, with
    temperature 0.5, the first sentence is decoded again and element 1 of the
    result must be the key with the highest value in element 2 (a dict).

    An exception raised while decoding counts as a failed check; a keyboard
    interrupt is not swallowed. A pass/fail summary is printed; nothing is
    returned.

    Args:
        mbr_decode (function): called as
            mbr_decode(sentence, 4, weighted_avg_overlap, jaccard_similarity,
                       model, temperature, vocab_file=VOCAB_FILE,
                       vocab_dir=VOCAB_DIR)
            (the 4 is presumably the number of samples - confirm against the
            assignment)
        model: passed through unchanged to `mbr_decode`
    """
    target = mbr_decode
    success = 0
    fails = 0

    TEMPERATURE = 0.0

    test_cases = [
        {
            "name":"simple_test_check",
            "input": "I am hungry",
            "expected":"Ich bin hungrig.",
            "error":"Expected output does not match"
        },
        {
            "name":"simple_test_check",
            "input":'Congratulations!',
            "expected":'Herzlichen Glückwunsch!',
            "error":"Expected output does not match"
        },
        {
            "name":"simple_test_check",
            "input":'You have completed the assignment!',
            "expected":'Sie haben die Abtretung abgeschlossen!',
            "error":"Expected output does not match"
        }
    ]

    for test_case in test_cases:
        try:
            result = target(test_case['input'], 4, weighted_avg_overlap, jaccard_similarity,
                            model, TEMPERATURE, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)
            passed = bool(result[0] == test_case['expected'])
        except Exception:
            passed = False

        if passed:
            success += 1
        else:
            print(test_case['error'])
            fails += 1

    # Test that the function returns the index with the highest score
    TEMPERATURE = 0.5
    test_case = test_cases[0]
    try:
        result = target(test_case['input'], 4, weighted_avg_overlap, jaccard_similarity,
                        model, TEMPERATURE, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)
        passed = bool(max(result[2], key=result[2].get) == result[1])
    except Exception:
        passed = False

    if passed:
        success += 1
    else:
        print('Use max function to select max_index')
        fails += 1

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
591