Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/master/C4 - Convolutional Neural Networks/Week 1/public_tests.py
Views: 4802
# Public tests for the Week 1 convolution exercises: zero padding, a single
# convolution step, the convolution forward pass and the pooling forward pass.
# The helpers single_test / multiple_test come from the course's test_utils
# module; presumably they call target(*case["input"]) and compare the result
# with case["expected"] according to case["name"] — confirm against test_utils.
import numpy as np
from test_utils import single_test, multiple_test


def zero_pad_test(target):
    """Check ``target(x, pad)`` against a hard-coded zero-padded tensor.

    ``x`` is a seeded (4, 3, 3, 2) array and ``pad`` is 2, so the expected
    result has shape (4, 7, 7, 2): the original values in the centre and
    zeros in the two outermost rows/columns of every spatial edge.
    """
    np.random.seed(1)
    x = np.random.randn(4, 3, 3, 2)
    pad = 2
    # Values of x (seed 1) rounded to 8 decimals; they form the centre of
    # the expected padded output.
    centre = np.array([
        [[[1.62434536, -0.61175641], [-0.52817175, -1.07296862], [0.86540763, -2.3015387]],
         [[1.74481176, -0.7612069], [0.3190391, -0.24937038], [1.46210794, -2.06014071]],
         [[-0.3224172, -0.38405435], [1.13376944, -1.09989127], [-0.17242821, -0.87785842]]],
        [[[0.04221375, 0.58281521], [-1.10061918, 1.14472371], [0.90159072, 0.50249434]],
         [[0.90085595, -0.68372786], [-0.12289023, -0.93576943], [-0.26788808, 0.53035547]],
         [[-0.69166075, -0.39675353], [-0.6871727, -0.84520564], [-0.67124613, -0.0126646]]],
        [[[-1.11731035, 0.2344157], [1.65980218, 0.74204416], [-0.19183555, -0.88762896]],
         [[-0.74715829, 1.6924546], [0.05080775, -0.63699565], [0.19091548, 2.10025514]],
         [[0.12015895, 0.61720311], [0.30017032, -0.35224985], [-1.1425182, -0.34934272]]],
        [[[-0.20889423, 0.58662319], [0.83898341, 0.93110208], [0.28558733, 0.88514116]],
         [[-0.75439794, 1.25286816], [0.51292982, -0.29809284], [0.48851815, -0.07557171]],
         [[1.13162939, 1.51981682], [2.18557541, -1.39649634], [-1.44411381, -0.50446586]]],
    ])
    # Zeros everywhere except the 3x3 spatial centre; channels are untouched.
    expected_output = np.zeros((4, 7, 7, 2))
    expected_output[:, 2:5, 2:5, :] = centre

    test_cases = [
        {
            "name": "datatype_check",
            "input": [x, pad],
            "expected": expected_output,
            "error": "Datatype mismatch.",
        },
        {
            "name": "equation_output_check",
            "input": [x, pad],
            "expected": expected_output,
            "error": "Wrong output",
        },
    ]

    single_test(test_cases, target)


def conv_single_step_test(target):
    """Check ``target(a_slice_prev, W, b)`` on seeded (4, 4, 3) inputs.

    The expected result is the scalar ``np.float64(-6.999089450680221)``.
    """
    np.random.seed(1)
    a_slice_prev = np.random.randn(4, 4, 3)
    W = np.random.randn(4, 4, 3)
    b = np.random.randn(1, 1, 1)
    expected_output = np.float64(-6.999089450680221)
    test_cases = [
        {
            "name": "datatype_check",
            "input": [a_slice_prev, W, b],
            "expected": expected_output,
            "error": "Datatype mismatch",
        },
        {
            "name": "shape_check",
            "input": [a_slice_prev, W, b],
            "expected": expected_output,
            "error": "Wrong shape",
        },
        {
            "name": "equation_output_check",
            "input": [a_slice_prev, W, b],
            "expected": expected_output,
            "error": "Wrong output",
        },
    ]

    multiple_test(test_cases, target)


def conv_forward_test(target):
    """Check ``target(A_prev, W, b, hparameters)`` -> ``(Z, cache)``.

    First, two shape-only checks with different pad/stride values catch
    implementations that hard-code the hyperparameters. Then a seeded run
    with pad=1, stride=2 is compared against a hard-coded (2, 3, 4, 8) ``Z``
    and the cache ``(A_prev, W, b, hparameters)``.
    """
    # Shape-only checks; the RNG is not seeded because only shapes matter.
    A_prev = np.random.randn(2, 5, 7, 4)
    W = np.random.randn(3, 3, 4, 8)
    b = np.random.randn(1, 1, 1, 8)
    Z, _ = target(A_prev, W, b, {"pad": 3, "stride": 1})
    Z_shape = Z.shape
    assert Z_shape[0] == A_prev.shape[0], f"m is wrong. Current: {Z_shape[0]}. Expected: {A_prev.shape[0]}"
    assert Z_shape[1] == 9, f"n_H is wrong. Current: {Z_shape[1]}. Expected: 9"
    assert Z_shape[2] == 11, f"n_W is wrong. Current: {Z_shape[2]}. Expected: 11"
    assert Z_shape[3] == W.shape[3], f"n_C is wrong. Current: {Z_shape[3]}. Expected: {W.shape[3]}"

    Z, _ = target(A_prev, W, b, {"pad": 0, "stride": 2})
    assert Z.shape == (2, 2, 3, 8), "Wrong shape. Don't hard code the pad and stride values in the function"

    # Value check on seeded inputs.
    np.random.seed(1)
    A_prev = np.random.randn(2, 5, 7, 4)
    W = np.random.randn(3, 3, 4, 8)
    b = np.random.randn(1, 1, 1, 8)
    hparameters = {"pad": 1,
                   "stride": 2}
    expected_Z = np.array(
        [[[[-2.65112363, -0.37849177, -1.97054929, -1.96235299,
            -1.72259872, 0.4676693, -6.43434016, 1.10764994],
           [4.67692928, 4.29865415, -1.3608031, 0.80532859,
            -2.88480108, 8.95280034, 5.32627807, -1.82635258],
           [-2.05881174, 3.40859795, 0.3502282, 0.68303626,
            -1.88328065, -1.87480174, 5.8008721, 0.0700918],
           [-3.50141791, 2.704286, 0.28341346, 4.15637411,
            -0.46575834, -0.43668824, -5.56866106, 1.72288033]],

          [[-2.32126108, 0.91040602, 2.31852532, 0.98842271,
            3.31716611, 4.05638832, -2.48135123, 0.95872443],
           [6.03978907, -6.96477888, -1.20799344, 2.68913374,
            -4.35744033, 10.59355329, 3.20856901, 13.98735978],
           [-3.01280755, -2.90226517, -8.34171936, -5.26220853,
            5.6630696, 1.08704033, 2.20430705, -10.73218294],
           [-6.24198266, -0.53158832, -3.29654954, -1.81865997,
            0.59196322, 2.51134745, -4.24924673, 5.21936641]],

          [[-2.22187412, -0.95259173, -5.99441273, 0.79147932,
            1.16919278, -0.17321161, -3.26346299, -3.62407578],
           [-2.17796037, 8.07171329, -0.5772704, 3.36286738,
            4.48113645, -2.89198428, 10.99288867, 3.03171932],
           [-12.49991261, 5.26845833, -1.67648614, -8.65695762,
            -10.68157258, 6.71492428, 2.83839971, 4.47259772],
           [0.11421092, -1.90872424, -3.28117601, 0.89922467,
            0.83985348, -0.25127044, -0.94409718, 5.17244412]]],


         [[[1.97649814, 2.76743075, -6.39611007, 2.95378171,
            -0.81235239, -0.53333631, 0.71268871, 4.91385105],
           [-5.14401869, 6.97041391, -4.53976469, 5.89092653,
            -5.74606931, 2.74256558, 3.02124802, -10.04187592],
           [5.53871187, -8.55886701, -4.70962135, 2.55966738,
            -2.66959504, 5.60010695, -8.37253342, 4.18848278],
           [0.63364517, -3.71848223, -3.67072772, 4.34226476,
            -1.21894465, 3.68929452, 5.89166305, 0.94256457]],

          [[2.36049402, -3.09696204, 8.33521755, 3.04680748,
            3.7964542, 0.66488788, 1.9935476, 1.54396221],
           [-7.73457048, 0.287562, 7.97481218, 3.32415996,
            -4.07121488, 2.69182963, 4.1356109, -5.16178423],
           [-6.95635186, -0.10924121, -4.12526441, 0.62578199,
            4.69492086, -3.52748877, 3.63168271, 0.64007629],
           [7.94980014, 5.71855659, 3.49970333, 12.7718152,
            8.84959478, 2.37150319, -1.42531648, -0.51126641]],

          [[-5.29658283, -4.20466999, -6.63067766, -9.87831724,
            -5.32130395, 7.32417919, 2.96011091, 7.60669481],
           [11.54630784, -1.93157244, 2.26699242, 7.62184275,
            5.40584348, -2.88837958, -1.46981877, 7.91314719],
           [5.94067877, 3.50739649, 0.82512202, 4.80655489,
            -4.1044945, 4.14358541, 0.13194885, 4.35397285],
           [4.91298364, -1.44499772, 5.9392078, -3.92690408,
            2.12840309, 1.27237402, 1.56992581, 0.44270565]]]]
    )
    expected_cache = (A_prev, W, b, hparameters)
    expected_output = (expected_Z, expected_cache)
    test_cases = [
        {
            "name": "datatype_check",
            "input": [A_prev, W, b, hparameters],
            "expected": expected_output,
            "error": "Datatype mismatch",
        },
        {
            "name": "shape_check",
            "input": [A_prev, W, b, hparameters],
            "expected": expected_output,
            "error": "Wrong shape",
        },
        {
            "name": "equation_output_check",
            "input": [A_prev, W, b, hparameters],
            "expected": expected_output,
            "error": "Wrong output",
        },
    ]

    multiple_test(test_cases, target)


def pool_forward_test(target):
    """Check ``target(A_prev, hparameters, mode)`` -> ``(A, cache)``.

    A shape-only check with f=2, stride=2 runs first. Then seeded inputs
    with f=3, stride=1 are compared against hard-coded (2, 3, 3, 3) results
    for both ``"max"`` and ``"average"`` modes, with cache
    ``(A_prev, hparameters)``.
    """
    # Shape-only check; the RNG is not seeded because only shapes matter.
    A_prev = np.random.randn(2, 5, 7, 3)
    A, _ = target(A_prev, {"stride": 2, "f": 2}, mode="average")
    A_shape = A.shape
    assert A_shape[0] == A_prev.shape[0], f"m is wrong. Current: {A_shape[0]}. Expected: {A_prev.shape[0]}"
    assert A_shape[1] == 2, f"n_H is wrong. Current: {A_shape[1]}. Expected: 2"
    assert A_shape[2] == 3, f"n_W is wrong. Current: {A_shape[2]}. Expected: 3"
    assert A_shape[3] == A_prev.shape[3], f"n_C is wrong. Current: {A_shape[3]}. Expected: {A_prev.shape[3]}"

    # Value checks on seeded inputs.
    np.random.seed(1)
    A_prev = np.random.randn(2, 5, 5, 3)
    hparameters = {"stride": 1, "f": 3}
    expected_cache = (A_prev, hparameters)

    expected_A_max = np.array(
        [[[[1.74481176, 0.90159072, 1.65980218],
           [1.74481176, 1.46210794, 1.65980218],
           [1.74481176, 1.6924546, 1.65980218]],

          [[1.14472371, 0.90159072, 2.10025514],
           [1.14472371, 0.90159072, 1.65980218],
           [1.14472371, 1.6924546, 1.65980218]],

          [[1.13162939, 1.51981682, 2.18557541],
           [1.13162939, 1.51981682, 2.18557541],
           [1.13162939, 1.6924546, 2.18557541]]],


         [[[1.19891788, 0.84616065, 0.82797464],
           [0.69803203, 0.84616065, 1.2245077],
           [0.69803203, 1.12141771, 1.2245077]],

          [[1.96710175, 0.84616065, 1.27375593],
           [1.96710175, 0.84616065, 1.23616403],
           [1.62765075, 1.12141771, 1.2245077]],

          [[1.96710175, 0.86888616, 1.27375593],
           [1.96710175, 0.86888616, 1.23616403],
           [1.62765075, 1.12141771, 0.79280687]]]]
    )

    expected_output_max = (expected_A_max, expected_cache)

    expected_A_average = np.array(
        [[[[-3.01046719e-02, -3.24021315e-03, -3.36298859e-01],
           [1.43310483e-01, 1.93146751e-01, -4.44905196e-01],
           [1.28934436e-01, 2.22428468e-01, 1.25067597e-01]],

          [[-3.81801899e-01, 1.59993515e-02, 1.70562706e-01],
           [4.73707165e-02, 2.59244658e-02, 9.20338402e-02],
           [3.97048605e-02, 1.57189094e-01, 3.45302489e-01]],

          [[-3.82680519e-01, 2.32579951e-01, 6.25997903e-01],
           [-2.47157416e-01, -3.48524998e-04, 3.50539717e-01],
           [-9.52551510e-02, 2.68511000e-01, 4.66056368e-01]]],


         [[[-1.73134159e-01, 3.23771981e-01, -3.43175716e-01],
           [3.80634669e-02, 7.26706274e-02, -2.30268958e-01],
           [2.03009393e-02, 1.41414785e-01, -1.23158476e-02]],

          [[4.44976963e-01, -2.61694592e-03, -3.10403073e-01],
           [5.08114737e-01, -2.34937338e-01, -2.39611830e-01],
           [1.18726772e-01, 1.72552294e-01, -2.21121966e-01]],

          [[4.29449255e-01, 8.44699612e-02, -2.72909051e-01],
           [6.76351685e-01, -1.20138225e-01, -2.44076712e-01],
           [1.50774518e-01, 2.89111751e-01, 1.23238536e-03]]]]
    )
    expected_output_average = (expected_A_average, expected_cache)
    test_cases = [
        {
            "name": "datatype_check",
            "input": [A_prev, hparameters, "max"],
            "expected": expected_output_max,
            "error": "Datatype mismatch in MAX-Pool",
        },
        {
            "name": "shape_check",
            "input": [A_prev, hparameters, "max"],
            "expected": expected_output_max,
            "error": "Wrong shape in MAX-Pool",
        },
        {
            "name": "equation_output_check",
            "input": [A_prev, hparameters, "max"],
            "expected": expected_output_max,
            "error": "Wrong output in MAX-Pool",
        },
        {
            "name": "datatype_check",
            "input": [A_prev, hparameters, "average"],
            "expected": expected_output_average,
            "error": "Datatype mismatch in AVG-Pool",
        },
        {
            "name": "shape_check",
            "input": [A_prev, hparameters, "average"],
            "expected": expected_output_average,
            "error": "Wrong shape in AVG-Pool",
        },
        {
            "name": "equation_output_check",
            "input": [A_prev, hparameters, "average"],
            "expected": expected_output_average,
            "error": "Wrong output in AVG-Pool",
        },
    ]

    multiple_test(test_cases, target)

######################################
############## UNGRADED ##############
######################################


def conv_backward_test(target):
    """Placeholder for the ungraded ``conv_backward`` exercise.

    No fixtures (inputs or expected values) are defined for this exercise,
    so calling it raises ``NotImplementedError`` explicitly.
    """
    raise NotImplementedError("conv_backward_test: no fixtures are defined for this ungraded exercise")


def create_mask_from_window_test(target):
    """Placeholder for the ungraded ``create_mask_from_window`` exercise.

    No fixtures (inputs or expected values) are defined for this exercise,
    so calling it raises ``NotImplementedError`` explicitly.
    """
    raise NotImplementedError("create_mask_from_window_test: no fixtures are defined for this ungraded exercise")


def distribute_value_test(target):
    """Placeholder for the ungraded ``distribute_value`` exercise.

    No fixtures (inputs or expected values) are defined for this exercise,
    so calling it raises ``NotImplementedError`` explicitly.
    """
    raise NotImplementedError("distribute_value_test: no fixtures are defined for this ungraded exercise")


def pool_backward_test(target):
    """Placeholder for the ungraded ``pool_backward`` exercise.

    No fixtures (inputs or expected values) are defined for this exercise,
    so calling it raises ``NotImplementedError`` explicitly.
    """
    raise NotImplementedError("pool_backward_test: no fixtures are defined for this ungraded exercise")