Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/master/C4 - Convolutional Neural Networks/Week 3/Image Segmentation Unet/test_utils.py
Views: 4818
import numpy as np
from termcolor import colored

from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import concatenate


def _require(condition):
    """Raise ``AssertionError`` when ``condition`` is falsy.

    Used instead of the ``assert`` statement, which is removed when Python runs
    with ``-O`` and would make every check silently pass.
    """
    if not condition:
        raise AssertionError


def _type_at(container, key):
    """Return ``type(container[key])``, or ``"<missing>"`` if it cannot be read.

    Called from inside ``except`` handlers so that building the error message
    cannot itself raise (e.g. when the expected container lacks ``key``).
    """
    try:
        return type(container[key])
    except (KeyError, IndexError, TypeError):
        return "<missing>"


def _report_results(success, total, target):
    """Print the pass/fail tally and raise if any test case failed.

    Args:
        success: Number of test cases that passed.
        total: Total number of test cases run.
        target: The function under test (its ``__name__`` is used in the message).

    Raises:
        AssertionError: If ``success != total``.
    """
    if success == total:
        print("\033[92m All tests passed.")
    else:
        print('\033[92m', success, " Tests passed")
        print('\033[91m', total - success, " Tests failed")
        raise AssertionError("Not all tests were passed for {}. Check your equations and avoid using global variables inside the function.".format(target.__name__))


def comparator(learner, instructor):
    """Compare a learner's layer summary with the instructor's, row by row.

    Args:
        learner: Iterable of per-layer descriptor sequences (e.g. the output
            of ``summary``) built from the learner's model.
        instructor: Iterable of the expected per-layer descriptor sequences.

    Raises:
        AssertionError: If any row differs (rows are compared as tuples) or the
            two summaries have a different number of rows.
    """
    # Materialize both so a length mismatch can be detected: ``zip`` alone
    # stops at the shorter input, letting a model with missing/extra layers pass.
    learner = list(learner)
    instructor = list(instructor)
    for a, b in zip(learner, instructor):
        if tuple(a) != tuple(b):
            print(colored("Test failed", attrs=['bold']),
                  "\n Expected value \n\n", colored(f"{b}", "green"),
                  "\n\n does not match the input value: \n\n",
                  colored(f"{a}", "red"))
            raise AssertionError("Error in test")
    if len(learner) != len(instructor):
        print(colored("Test failed", attrs=['bold']),
              "\n Expected number of layers: ", colored(f"{len(instructor)}", "green"),
              "\n but got: ", colored(f"{len(learner)}", "red"))
        raise AssertionError("Error in test")
    print(colored("All tests passed!", "green"))


def summary(model):
    """Compile ``model`` and return a description of each of its layers.

    Each entry is ``[class name, output_shape, parameter count]`` followed by
    layer-specific extras:
      * ``Conv2D``: padding, activation function name, kernel initializer class name;
      * ``MaxPooling2D``: pool size;
      * ``Dropout``: dropout rate.

    Side effect: ``model`` is compiled in place with the adam optimizer,
    categorical cross-entropy loss and accuracy metric.
    """
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    result = []
    for layer in model.layers:
        descriptors = [layer.__class__.__name__, layer.output_shape, layer.count_params()]
        # Exact type checks on purpose: isinstance() would also match
        # subclasses (e.g. Conv2DTranspose, a Conv2D subclass in tf.keras 2.x)
        # and change the expected summaries.
        layer_type = type(layer)
        if layer_type is Conv2D:
            descriptors.append(layer.padding)
            descriptors.append(layer.activation.__name__)
            descriptors.append(layer.kernel_initializer.__class__.__name__)
        if layer_type is MaxPooling2D:
            descriptors.append(layer.pool_size)
        if layer_type is Dropout:
            descriptors.append(layer.rate)
        result.append(descriptors)
    return result


def datatype_check(expected_output, target_output, error):
    """Recursively check that ``target_output`` has the types of ``expected_output``.

    Dicts are checked key by key (keys taken from ``target_output``), lists and
    tuples element by element; any other value must be an instance of
    ``type(expected_output)``.

    Args:
        expected_output: Reference value whose (nested) types are expected.
        target_output: Value produced by the learner's function.
        error: Message prefix printed for each mismatching entry.

    Returns:
        1 if every checked entry matches; 0 if any entry of a dict/list/tuple
        fails (one message is printed per failing entry).

    Raises:
        AssertionError: If ``target_output`` is a leaf of the wrong type.
    """
    success = 0
    if isinstance(target_output, dict):
        for key in target_output:
            try:
                success += datatype_check(expected_output[key],
                                          target_output[key], error)
            except Exception:
                print("Error: {} in variable {}. Got {} but expected type {}".format(
                    error, key, type(target_output[key]), _type_at(expected_output, key)))
        return 1 if success == len(target_output) else 0
    elif isinstance(target_output, (tuple, list)):
        for i, item in enumerate(target_output):
            try:
                success += datatype_check(expected_output[i], item, error)
            except Exception:
                print("Error: {} in variable {}. Got {} but expected type {}".format(
                    error, i, type(item), _type_at(expected_output, i)))
        return 1 if success == len(target_output) else 0
    else:
        _require(isinstance(target_output, type(expected_output)))
        return 1


def equation_output_check(expected_output, target_output, error):
    """Recursively check that ``target_output`` equals ``expected_output``.

    Dicts are checked key by key (keys taken from ``target_output``), lists and
    tuples element by element. Leaves that have a ``shape`` attribute are
    compared with ``np.testing.assert_array_almost_equal`` (NumPy's default
    tolerance); other leaves must compare equal with ``==``.

    Returns:
        1 if every checked entry matches; 0 if any entry of a dict/list/tuple
        fails (one message is printed per failing entry).

    Raises:
        AssertionError: If ``target_output`` is a mismatching leaf.
    """
    success = 0
    if isinstance(target_output, dict):
        for key in target_output:
            try:
                success += equation_output_check(expected_output[key],
                                                 target_output[key], error)
            except Exception:
                print("Error: {} for variable {}.".format(error, key))
        return 1 if success == len(target_output) else 0
    elif isinstance(target_output, (tuple, list)):
        for i, item in enumerate(target_output):
            try:
                success += equation_output_check(expected_output[i], item, error)
            except Exception:
                print("Error: {} for variable in position {}.".format(error, i))
        return 1 if success == len(target_output) else 0
    else:
        if hasattr(target_output, 'shape'):
            np.testing.assert_array_almost_equal(target_output, expected_output)
        else:
            _require(target_output == expected_output)
        return 1


def shape_check(expected_output, target_output, error):
    """Recursively check that array leaves of ``target_output`` have the expected shapes.

    Dicts are checked key by key (keys taken from ``target_output``), lists and
    tuples element by element. Leaves with a ``shape`` attribute must match
    ``expected_output.shape``; leaves without one are accepted unchecked.

    Returns:
        1 if every checked entry matches; 0 if any entry of a dict/list/tuple
        fails (one message is printed per failing entry).

    Raises:
        AssertionError: If ``target_output`` is a leaf with the wrong shape.
    """
    success = 0
    if isinstance(target_output, dict):
        for key in target_output:
            try:
                success += shape_check(expected_output[key],
                                       target_output[key], error)
            except Exception:
                print("Error: {} for variable {}.".format(error, key))
        return 1 if success == len(target_output) else 0
    elif isinstance(target_output, (tuple, list)):
        for i, item in enumerate(target_output):
            try:
                success += shape_check(expected_output[i], item, error)
            except Exception:
                print("Error: {} for variable {}.".format(error, i))
        return 1 if success == len(target_output) else 0
    else:
        if hasattr(target_output, 'shape'):
            _require(target_output.shape == expected_output.shape)
        return 1


def single_test(test_cases, target):
    """Run flat (non-recursive) checks of ``target`` against ``test_cases``.

    Each test case is a dict with keys ``'name'`` (one of ``"datatype_check"``,
    ``"equation_output_check"``, ``"shape_check"``), ``'input'`` (positional
    arguments for ``target``), ``'expected'`` and ``'error'`` (message printed
    on failure). A case with any other name counts as failed without a message.

    Raises:
        AssertionError: If any test case failed.
    """
    success = 0
    for test_case in test_cases:
        try:
            if test_case['name'] == "datatype_check":
                _require(isinstance(target(*test_case['input']),
                                    type(test_case["expected"])))
                success += 1
            if test_case['name'] == "equation_output_check":
                _require(np.allclose(test_case["expected"],
                                     target(*test_case['input'])))
                success += 1
            if test_case['name'] == "shape_check":
                _require(test_case['expected'].shape == target(*test_case['input']).shape)
                success += 1
        except Exception:
            print("Error: " + test_case['error'])

    _report_results(success, len(test_cases), target)


def multiple_test(test_cases, target):
    """Run recursive checks of ``target`` against ``test_cases``.

    Same test-case format as ``single_test``, but ``target`` is called once per
    case and its (possibly nested dict/list/tuple) result is checked with
    ``datatype_check``, ``equation_output_check`` or ``shape_check``.

    Raises:
        AssertionError: If any test case failed.
    """
    success = 0
    for test_case in test_cases:
        try:
            target_answer = target(*test_case['input'])
            if test_case['name'] == "datatype_check":
                success += datatype_check(test_case['expected'], target_answer, test_case['error'])
            if test_case['name'] == "equation_output_check":
                success += equation_output_check(test_case['expected'], target_answer, test_case['error'])
            if test_case['name'] == "shape_check":
                success += shape_check(test_case['expected'], target_answer, test_case['error'])
        except Exception:
            print("Error: " + test_case['error'])

    _report_results(success, len(test_cases), target)