# Path: blob/main/C1 - Supervised Machine Learning - Regression and Classification/week3/C1W3A1/test_utils.py
# 3748 views
"""Helpers that grade a function's output against reference test cases.

The ``*_check`` helpers compare ``target_output`` (what the function under
test returned) with ``expected_output`` (the reference answer), recursing into
dicts, lists and tuples.  Container-level calls return ``1`` when every
element matches and ``0`` otherwise, printing a diagnostic for each failing
element; a mismatch at a non-container leaf raises ``AssertionError``.
"""
import numpy as np
from copy import deepcopy


def _expected_type(expected_output, key):
    """Return ``type(expected_output[key])`` for use in an error message.

    Falls back to the placeholder string ``"<missing>"`` when the key/index is
    absent or ``expected_output`` is not subscriptable, so that building a
    diagnostic inside an ``except`` handler can never raise a new error.
    """
    try:
        return type(expected_output[key])
    except (LookupError, TypeError):
        return "<missing>"


def _require(condition):
    """Raise ``AssertionError`` when ``condition`` is falsy.

    Used instead of the ``assert`` statement so the checks still run when
    Python is started with ``-O`` (which strips ``assert``).
    """
    if not condition:
        raise AssertionError


def _report(success, total, target):
    """Print the pass/fail summary and raise if any test case failed.

    Args:
        success: Number of test cases that passed.
        total: Number of test cases that were run.
        target: The function under test (only its ``__name__`` is used).

    Raises:
        AssertionError: If ``success != total``.
    """
    if success == total:
        print("\033[92m All tests passed.")
        return
    print('\033[92m', success, " Tests passed")
    print('\033[91m', total - success, " Tests failed")
    raise AssertionError(
        "Not all tests were passed for {}. Check your equations and avoid "
        "using global variables inside the function.".format(target.__name__))


def datatype_check(expected_output, target_output, error):
    """Check that ``target_output`` has the same (nested) types as ``expected_output``.

    Args:
        expected_output: Reference value; a dict, list/tuple, or leaf value.
        target_output: Value produced by the function under test.
        error: Description included in every printed mismatch message.

    Returns:
        1 if every element of a dict/list/tuple ``target_output`` has the
        expected type, otherwise 0 (one message is printed per bad element).

    Raises:
        AssertionError: If a non-container ``target_output`` is not an
            instance of ``type(expected_output)``.
    """
    success = 0
    if isinstance(target_output, dict):
        for key, value in target_output.items():
            try:
                success += datatype_check(expected_output[key], value, error)
            except Exception:
                print("Error: {} in variable {}. Got {} but expected type {}".format(
                    error, key, type(value), _expected_type(expected_output, key)))
        return int(success == len(target_output))
    if isinstance(target_output, (tuple, list)):
        for i, value in enumerate(target_output):
            try:
                success += datatype_check(expected_output[i], value, error)
            except Exception:
                # Report the received type first, matching the dict branch.
                print("Error: {} in variable {}. Got {} but expected type {}".format(
                    error, i, type(value), _expected_type(expected_output, i)))
        return int(success == len(target_output))
    _require(isinstance(target_output, type(expected_output)))
    return 1


def equation_output_check(expected_output, target_output, error):
    """Check that ``target_output`` numerically matches ``expected_output``.

    Objects with a ``shape`` attribute are compared with
    ``np.testing.assert_array_almost_equal``; other leaves use ``==``.

    Args:
        expected_output: Reference value; a dict, list/tuple, or leaf value.
        target_output: Value produced by the function under test.
        error: Description included in every printed mismatch message.

    Returns:
        1 if every element of a dict/list/tuple ``target_output`` matches,
        otherwise 0 (one message is printed per bad element).

    Raises:
        AssertionError: If a non-container ``target_output`` does not match.
    """
    success = 0
    if isinstance(target_output, dict):
        for key, value in target_output.items():
            try:
                success += equation_output_check(expected_output[key], value, error)
            except Exception:
                print("Error: {} for variable {}.".format(error, key))
        return int(success == len(target_output))
    if isinstance(target_output, (tuple, list)):
        for i, value in enumerate(target_output):
            try:
                success += equation_output_check(expected_output[i], value, error)
            except Exception:
                print("Error: {} for variable in position {}.".format(error, i))
        return int(success == len(target_output))
    if hasattr(target_output, 'shape'):
        np.testing.assert_array_almost_equal(target_output, expected_output)
    else:
        _require(target_output == expected_output)
    return 1


def shape_check(expected_output, target_output, error):
    """Check that ``target_output`` has the same (nested) shapes as ``expected_output``.

    Leaves without a ``shape`` attribute are not checked and count as a pass.

    Args:
        expected_output: Reference value; a dict, list/tuple, or leaf value.
        target_output: Value produced by the function under test.
        error: Description included in every printed mismatch message.

    Returns:
        1 if every element of a dict/list/tuple ``target_output`` has the
        expected shape, otherwise 0 (one message is printed per bad element).

    Raises:
        AssertionError: If a non-container ``target_output`` has a ``shape``
            different from ``expected_output.shape``.
    """
    success = 0
    if isinstance(target_output, dict):
        for key, value in target_output.items():
            try:
                success += shape_check(expected_output[key], value, error)
            except Exception:
                print("Error: {} for variable {}.".format(error, key))
        return int(success == len(target_output))
    if isinstance(target_output, (tuple, list)):
        for i, value in enumerate(target_output):
            try:
                success += shape_check(expected_output[i], value, error)
            except Exception:
                print("Error: {} for variable {}.".format(error, i))
        return int(success == len(target_output))
    if hasattr(target_output, 'shape'):
        _require(target_output.shape == expected_output.shape)
    return 1


def single_test(test_cases, target):
    """Run flat (non-nested) checks of ``target`` against ``test_cases``.

    Each test case is a dict with keys ``'name'`` (one of
    ``"datatype_check"``, ``"equation_output_check"``, ``"shape_check"``),
    ``'input'`` (positional arguments for ``target``), ``'expected'`` and
    ``'error'`` (message printed on failure).  A case with an unrecognised
    name is counted as failed without a message.

    Raises:
        AssertionError: If any test case fails.
    """
    predicates = {
        "datatype_check": lambda expected, answer: isinstance(answer, type(expected)),
        "equation_output_check": lambda expected, answer: np.allclose(expected, answer),
        "shape_check": lambda expected, answer: expected.shape == answer.shape,
    }
    success = 0
    for test_case in test_cases:
        try:
            predicate = predicates.get(test_case['name'])
            if predicate is None:
                continue
            # Deep-copy the inputs (as multiple_test does) so a target that
            # mutates its arguments cannot corrupt the shared test case.
            answer = target(*deepcopy(test_case['input']))
            _require(predicate(test_case['expected'], answer))
            success += 1
        except Exception:
            print("Error: " + test_case['error'])
    _report(success, len(test_cases), target)


def multiple_test(test_cases, target):
    """Run recursive checks of ``target`` against ``test_cases``.

    Test cases have the same layout as for ``single_test``; the output is
    compared with ``datatype_check``, ``equation_output_check`` or
    ``shape_check`` so dict/list/tuple results are checked element-wise.
    A case with an unrecognised name is counted as failed without a message.

    Raises:
        AssertionError: If any test case fails.
    """
    checkers = {
        "datatype_check": datatype_check,
        "equation_output_check": equation_output_check,
        "shape_check": shape_check,
    }
    success = 0
    for test_case in test_cases:
        try:
            # Copy the inputs so in-place mutation by target stays local.
            test_input = deepcopy(test_case['input'])
            target_answer = target(*test_input)
            checker = checkers.get(test_case['name'])
            if checker is not None:
                success += checker(test_case['expected'], target_answer,
                                   test_case['error'])
        except Exception:
            print('\33[30m', "Error: " + test_case['error'])
    _report(success, len(test_cases), target)