CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Sign Up | Sign In
amanchadha

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: amanchadha/coursera-deep-learning-specialization
Path: blob/master/C4 - Convolutional Neural Networks/Week 3/Image Segmentation Unet/test_utils.py
Views: 4818
1
import numpy as np
2
from termcolor import colored
3
4
from tensorflow.keras.layers import Input
5
from tensorflow.keras.layers import Conv2D
6
from tensorflow.keras.layers import MaxPooling2D
7
from tensorflow.keras.layers import Dropout
8
from tensorflow.keras.layers import Conv2DTranspose
9
from tensorflow.keras.layers import concatenate
10
11
# Compare the two inputs
12
def comparator(learner, instructor):
    """Compare two sequences of descriptors entry by entry.

    Each pair of entries is converted to a tuple and compared. On the first
    difference a "Test failed" report is printed and an ``AssertionError`` is
    raised. If every entry matches, "All tests passed!" is printed.

    Args:
        learner: Iterable of descriptor sequences produced by the code
            under test.
        instructor: Iterable of the expected descriptor sequences.

    Raises:
        AssertionError: If any pair of entries differs, or if the two
            inputs do not contain the same number of entries.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import list.
    from itertools import zip_longest

    # A plain zip() stops at the shorter input, which would let a learner
    # result with missing or extra entries pass. zip_longest pads the
    # shorter side with a private sentinel so the mismatch is detected.
    missing = object()
    for a, b in zip_longest(learner, instructor, fillvalue=missing):
        length_mismatch = a is missing or b is missing
        if length_mismatch or tuple(a) != tuple(b):
            expected = "<no entry>" if b is missing else b
            received = "<no entry>" if a is missing else a
            print(colored("Test failed", attrs=['bold']),
                  "\n Expected value \n\n", colored(f"{expected}", "green"),
                  "\n\n does not match the input value: \n\n",
                  colored(f"{received}", "red"))
            raise AssertionError("Error in test")
    print(colored("All tests passed!", "green"))
21
22
# extracts the description of a given model
23
def summary(model):
    """Compile ``model`` and describe each of its layers.

    The model is compiled with the 'adam' optimizer, categorical
    cross-entropy loss and the 'accuracy' metric before its layers are
    inspected.

    Args:
        model: Object exposing ``compile(...)`` and a ``layers`` iterable.

    Returns:
        A list with one entry per layer. Every entry starts with
        ``[class name, output_shape, parameter count]``; Conv2D layers add
        padding, activation name and kernel-initializer class name,
        MaxPooling2D layers add the pool size, and Dropout layers add the
        rate.
    """
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    descriptions = []
    for layer in model.layers:
        entry = [layer.__class__.__name__, layer.output_shape, layer.count_params()]
        layer_type = type(layer)
        # NOTE(review): the exact-type comparison is kept on purpose;
        # isinstance() would also match subclasses of these layer classes
        # and could change the descriptors that are produced.
        if layer_type == Conv2D:
            entry += [
                layer.padding,
                layer.activation.__name__,
                layer.kernel_initializer.__class__.__name__,
            ]
        elif layer_type == MaxPooling2D:
            entry.append(layer.pool_size)
        elif layer_type == Dropout:
            entry.append(layer.rate)
        descriptions.append(entry)
    return descriptions
40
41
def _expected_type_at(expected_output, position):
    """Return the type of ``expected_output[position]``, or a placeholder if it cannot be looked up."""
    try:
        return type(expected_output[position])
    except (KeyError, IndexError, TypeError):
        return "<missing>"


def datatype_check(expected_output, target_output, error):
    """Recursively check that ``target_output`` has the same types as ``expected_output``.

    Dicts are walked key by key and tuples/lists position by position,
    following the structure of ``target_output``. Any other value is checked
    with ``isinstance`` against the type of the expected value.

    Args:
        expected_output: Reference value (or nested dict/tuple/list of values).
        target_output: Value produced by the code under test.
        error: Message fragment printed when a contained element fails.

    Returns:
        1 if every element passed, 0 if at least one contained element failed.

    Raises:
        AssertionError: If ``target_output`` is not a dict, tuple or list and
            is not an instance of ``type(expected_output)``.
    """
    if isinstance(target_output, dict):
        success = 0
        for key in target_output:
            try:
                success += datatype_check(expected_output[key],
                                          target_output[key], error)
            except Exception:
                # The expected entry may be missing altogether, so it is
                # looked up defensively instead of failing inside the handler.
                print("Error: {} in variable {}. Got {} but expected type {}".format(
                    error, key, type(target_output[key]),
                    _expected_type_at(expected_output, key)))
        return 1 if success == len(target_output) else 0

    if isinstance(target_output, (tuple, list)):
        success = 0
        for i, item in enumerate(target_output):
            try:
                success += datatype_check(expected_output[i], item, error)
            except Exception:
                print("Error: {} in variable {}. Got {} but expected type {}".format(
                    error, i, type(item),
                    _expected_type_at(expected_output, i)))
        return 1 if success == len(target_output) else 0

    # Explicit raise instead of `assert` so the check survives `python -O`.
    if not isinstance(target_output, type(expected_output)):
        raise AssertionError("Got {} but expected type {}".format(
            type(target_output), type(expected_output)))
    return 1
71
72
def equation_output_check(expected_output, target_output, error):
    """Recursively check that ``target_output`` has the same values as ``expected_output``.

    Dicts are walked key by key and tuples/lists position by position,
    following the structure of ``target_output``. A value with a ``shape``
    attribute is compared with ``np.testing.assert_array_almost_equal``;
    any other value is compared with ``==``.

    Args:
        expected_output: Reference value (or nested dict/tuple/list of values).
        target_output: Value produced by the code under test.
        error: Message fragment printed when a contained element fails.

    Returns:
        1 if every element passed, 0 if at least one contained element failed.

    Raises:
        AssertionError: If ``target_output`` is not a dict, tuple or list and
            does not match ``expected_output``.
    """
    if isinstance(target_output, dict):
        success = 0
        for key in target_output:
            try:
                success += equation_output_check(expected_output[key],
                                                 target_output[key], error)
            # `Exception` rather than a bare except, so KeyboardInterrupt
            # and SystemExit are not swallowed and reported as a failure.
            except Exception:
                print("Error: {} for variable {}.".format(error, key))
        return 1 if success == len(target_output) else 0

    if isinstance(target_output, (tuple, list)):
        success = 0
        for i, item in enumerate(target_output):
            try:
                success += equation_output_check(expected_output[i], item, error)
            except Exception:
                print("Error: {} for variable in position {}.".format(error, i))
        return 1 if success == len(target_output) else 0

    if hasattr(target_output, 'shape'):
        np.testing.assert_array_almost_equal(target_output, expected_output)
    elif not target_output == expected_output:
        # Explicit raise instead of `assert` so the check survives `python -O`.
        raise AssertionError("Got {} but expected {}".format(
            target_output, expected_output))
    return 1
104
105
def shape_check(expected_output, target_output, error):
    """Recursively check that ``target_output`` has the same shapes as ``expected_output``.

    Dicts are walked key by key and tuples/lists position by position,
    following the structure of ``target_output``. A value with a ``shape``
    attribute must have the same ``shape`` as the expected value; a value
    without one is accepted as is.

    Args:
        expected_output: Reference value (or nested dict/tuple/list of values).
        target_output: Value produced by the code under test.
        error: Message fragment printed when a contained element fails.

    Returns:
        1 if every element passed, 0 if at least one contained element failed.

    Raises:
        AssertionError: If ``target_output`` is not a dict, tuple or list,
            has a ``shape``, and that shape differs from the expected one.
    """
    if isinstance(target_output, dict):
        success = 0
        for key in target_output:
            try:
                success += shape_check(expected_output[key],
                                       target_output[key], error)
            # `Exception` rather than a bare except, so KeyboardInterrupt
            # and SystemExit are not swallowed and reported as a failure.
            except Exception:
                print("Error: {} for variable {}.".format(error, key))
        return 1 if success == len(target_output) else 0

    if isinstance(target_output, (tuple, list)):
        success = 0
        for i, item in enumerate(target_output):
            try:
                success += shape_check(expected_output[i], item, error)
            except Exception:
                print("Error: {} for variable {}.".format(error, i))
        return 1 if success == len(target_output) else 0

    if hasattr(target_output, 'shape'):
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if not target_output.shape == expected_output.shape:
            raise AssertionError("Got shape {} but expected shape {}".format(
                target_output.shape, expected_output.shape))
    return 1
134
135
def single_test(test_cases, target):
    """Run a list of test cases against ``target`` and report the outcome.

    Each test case is a dict with the keys 'name', 'input', 'expected' and
    'error'. 'name' selects the check:

    * "datatype_check": the result must be an instance of the type of
      'expected'.
    * "equation_output_check": the result must satisfy ``np.allclose``
      against 'expected'.
    * "shape_check": the result's ``shape`` must equal ``expected.shape``.

    ``target`` is called with the unpacked 'input' of the test case. A case
    whose name matches none of the checks is counted as not passed without
    printing a message.

    Args:
        test_cases: List of test-case dicts as described above.
        target: Callable under test.

    Raises:
        AssertionError: If at least one test case did not pass.
    """
    success = 0
    for test_case in test_cases:
        try:
            name = test_case['name']
            # Conditions are evaluated explicitly instead of with `assert`,
            # which is stripped under `python -O` and would make every case
            # count as passed.
            if name == "datatype_check":
                passed = bool(isinstance(target(*test_case['input']),
                                         type(test_case["expected"])))
            elif name == "equation_output_check":
                passed = bool(np.allclose(test_case["expected"],
                                          target(*test_case['input'])))
            elif name == "shape_check":
                passed = bool(test_case['expected'].shape == target(*test_case['input']).shape)
            else:
                continue
        # `Exception` rather than a bare except, so KeyboardInterrupt and
        # SystemExit are not swallowed and reported as a failed test.
        except Exception:
            passed = False

        if passed:
            success += 1
        else:
            print("Error: " + test_case['error'])

    if success == len(test_cases):
        print("\033[92m All tests passed.")
    else:
        print('\033[92m', success, " Tests passed")
        print('\033[91m', len(test_cases) - success, " Tests failed")
        raise AssertionError("Not all tests were passed for {}. Check your equations and avoid using global variables inside the function.".format(target.__name__))
159
160
def multiple_test(test_cases, target):
    """Run a list of test cases against ``target`` using the recursive checkers.

    Each test case is a dict with the keys 'name', 'input', 'expected' and
    'error'. ``target`` is called once per case with the unpacked 'input',
    and the result is handed to ``datatype_check``,
    ``equation_output_check`` or ``shape_check`` depending on 'name'. A case
    whose name matches none of them is counted as not passed without
    printing a message.

    Args:
        test_cases: List of test-case dicts as described above.
        target: Callable under test.

    Raises:
        AssertionError: If at least one test case did not pass.
    """
    success = 0
    for test_case in test_cases:
        try:
            target_answer = target(*test_case['input'])
            name = test_case['name']
            if name == "datatype_check":
                success += datatype_check(test_case['expected'], target_answer, test_case['error'])
            elif name == "equation_output_check":
                success += equation_output_check(test_case['expected'], target_answer, test_case['error'])
            elif name == "shape_check":
                success += shape_check(test_case['expected'], target_answer, test_case['error'])
        # `Exception` rather than a bare except, so KeyboardInterrupt and
        # SystemExit are not swallowed and reported as a failed test.
        except Exception:
            print("Error: " + test_case['error'])

    if success == len(test_cases):
        print("\033[92m All tests passed.")
    else:
        print('\033[92m', success, " Tests passed")
        print('\033[91m', len(test_cases) - success, " Tests failed")
        raise AssertionError("Not all tests were passed for {}. Check your equations and avoid using global variables inside the function.".format(target.__name__))
180
181
182
183
184