Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/master/C5 - Sequence Models/Week 1/Building a Recurrent Neural Network - Step by Step/public_tests.py
Views: 4819
import numpy as np
from rnn_utils import *  # supplies `softmax`; wildcard kept in case callers rely on names it re-exports


def rnn_cell_forward_tests(target):
    """Check a single-step RNN cell implementation.

    Three cases are run. Each one zeroes every input except one term, so a
    failure points at the term that is wrong. For ``a_next`` the expected
    value is ``tanh`` of that single term. For ``yt_pred`` it is ``softmax``
    of either ``by`` (case 1, ``Wya`` zero) or ``Wya @ a_next`` (cases 2-3,
    ``by`` zero).

    Args:
        target: callable invoked as ``target(xt, a_prev, parameters)`` that
            returns ``(a_next, yt_pred, cache)``.

    Raises:
        AssertionError: on the first wrong shape or value.
    """
    # Case 1: only the biases contribute. xt and a_prev are zero, and Wya is
    # zero, so yt_pred must reduce to softmax(by).
    a_prev_tmp = np.zeros((5, 10))
    xt_tmp = np.zeros((3, 10))
    parameters_tmp = {
        'Waa': np.random.randn(5, 5),
        'Wax': np.random.randn(5, 3),
        'Wya': np.zeros((2, 5)),
        'ba': np.random.randn(5, 1),
        'by': np.random.randn(2, 1),
    }

    a_next_tmp, yt_pred_tmp, cache_tmp = target(xt_tmp, a_prev_tmp, parameters_tmp)

    assert a_next_tmp.shape == (5, 10), f"Wrong shape for a_next. Expected (5, 10) != {a_next_tmp.shape}"
    assert yt_pred_tmp.shape == (2, 10), f"Wrong shape for yt_pred. Expected (2, 10) != {yt_pred_tmp.shape}"
    assert cache_tmp[0].shape == (5, 10), "Wrong shape in cache->a_next"
    assert cache_tmp[1].shape == (5, 10), "Wrong shape in cache->a_prev"
    assert cache_tmp[2].shape == (3, 10), "Wrong shape in cache->x_t"
    assert len(cache_tmp[3].keys()) == 5, "Wrong number of parameters in cache. Expected 5"

    assert np.allclose(np.tanh(parameters_tmp['ba']), a_next_tmp), "Problem 1 in a_next expression. Related to ba?"
    assert np.allclose(softmax(parameters_tmp['by']), yt_pred_tmp), "Problem 1 in yt_pred expression. Related to by?"

    # Case 2: only xt contributes to a_next (a_prev and both biases are zero).
    a_prev_tmp = np.zeros((5, 10))
    xt_tmp = np.random.randn(3, 10)
    parameters_tmp['Wax'] = np.random.randn(5, 3)
    parameters_tmp['ba'] = np.zeros((5, 1))
    parameters_tmp['by'] = np.zeros((2, 1))
    # Wya is still zero from case 1. That would make yt_pred uniform whatever
    # a_next is, so use a random Wya to make the yt_pred check depend on a_next.
    parameters_tmp['Wya'] = np.random.randn(2, 5)

    a_next_tmp, yt_pred_tmp, cache_tmp = target(xt_tmp, a_prev_tmp, parameters_tmp)

    assert np.allclose(np.tanh(np.dot(parameters_tmp['Wax'], xt_tmp)), a_next_tmp), "Problem 2 in a_next expression. Related to xt?"
    assert np.allclose(softmax(np.dot(parameters_tmp['Wya'], a_next_tmp)), yt_pred_tmp), "Problem 2 in yt_pred expression. Related to a_next?"

    # Case 3: only a_prev contributes to a_next (xt and both biases are zero).
    # Wya keeps the non-zero value set in case 2.
    a_prev_tmp = np.random.randn(5, 10)
    xt_tmp = np.zeros((3, 10))
    parameters_tmp['Waa'] = np.random.randn(5, 5)
    parameters_tmp['ba'] = np.zeros((5, 1))
    parameters_tmp['by'] = np.zeros((2, 1))

    a_next_tmp, yt_pred_tmp, cache_tmp = target(xt_tmp, a_prev_tmp, parameters_tmp)

    assert np.allclose(np.tanh(np.dot(parameters_tmp['Waa'], a_prev_tmp)), a_next_tmp), "Problem 3 in a_next expression. Related to a_prev?"
    assert np.allclose(softmax(np.dot(parameters_tmp['Wya'], a_next_tmp)), yt_pred_tmp), "Problem 3 in yt_pred expression. Related to a_next?"

    print("\033[92mAll tests passed")


def rnn_forward_test(target):
    """Check a full-sequence RNN forward pass against reference values.

    Args:
        target: callable invoked as ``target(x, a0, parameters)`` that returns
            ``(a, y_pred, caches)``.

    Raises:
        AssertionError: on the first wrong shape or value.
    """
    # The hard-coded expected values below depend on this seed AND on the
    # exact order of the random draws that follow; do not reorder them.
    np.random.seed(17)
    T_x = 13
    m = 8
    n_x = 4
    n_a = 7
    n_y = 3
    x_tmp = np.random.randn(n_x, m, T_x)
    a0_tmp = np.random.randn(n_a, m)
    parameters_tmp = {
        'Waa': np.random.randn(n_a, n_a),
        'Wax': np.random.randn(n_a, n_x),
        'Wya': np.random.randn(n_y, n_a),
        'ba': np.random.randn(n_a, 1),
        'by': np.random.randn(n_y, 1),
    }

    a, y_pred, caches = target(x_tmp, a0_tmp, parameters_tmp)

    # `{(n_a, m, T_x)}` formats the tuple once; the previous `({n_a, m, T_x})`
    # printed doubled parentheses.
    assert a.shape == (n_a, m, T_x), f"Wrong shape for a. Expected: {(n_a, m, T_x)} != {a.shape}"
    assert y_pred.shape == (n_y, m, T_x), f"Wrong shape for y_pred. Expected: {(n_y, m, T_x)} != {y_pred.shape}"
    assert len(caches[0]) == T_x, f"len(cache) must be T_x = {T_x}"

    assert np.allclose(a[5, 2, 2:6], [0.99999291, 0.99332189, 0.9921928, 0.99503445]), "Wrong values for a"
    assert np.allclose(y_pred[2, 1, 1:5], [0.19428, 0.14292, 0.24993, 0.00119], atol=1e-4), "Wrong values for y_pred"
    assert np.allclose(caches[1], x_tmp), "Fail check: cache[1] != x_tmp"

    print("\033[92mAll tests passed")


def lstm_cell_forward_test(target):
    """Check a single-step LSTM cell implementation against reference values.

    Args:
        target: callable invoked as ``target(xt, a_prev, c_prev, parameters)``
            that returns ``(a_next, c_next, y_pred, cache)``. ``cache`` must have
            exactly 10 entries; indices 0-8 are checked below.

    Raises:
        AssertionError: on the first wrong shape or value.
    """
    # Expected values depend on this seed and on the draw order below.
    np.random.seed(212)
    m = 8
    n_x = 4
    n_a = 7
    n_y = 3
    x = np.random.randn(n_x, m)
    a0 = np.random.randn(n_a, m)
    c0 = np.random.randn(n_a, m)
    # Dict-literal values are evaluated left to right, so the draw order matches
    # the key order.
    params = {
        'Wf': np.random.randn(n_a, n_a + n_x),
        'bf': np.random.randn(n_a, 1),
        'Wi': np.random.randn(n_a, n_a + n_x),
        'bi': np.random.randn(n_a, 1),
        'Wo': np.random.randn(n_a, n_a + n_x),
        'bo': np.random.randn(n_a, 1),
        'Wc': np.random.randn(n_a, n_a + n_x),
        'bc': np.random.randn(n_a, 1),
        'Wy': np.random.randn(n_y, n_a),
        'by': np.random.randn(n_y, 1),
    }
    a_next, c_next, y_pred, cache = target(x, a0, c0, params)

    assert len(cache) == 10, "Don't change the cache"

    # (cache index, label used in the message, expected shape), checked in this order.
    expected_cache_shapes = [
        (4, "ft", (n_a, m)),
        (5, "it", (n_a, m)),
        (6, "cct", (n_a, m)),
        (1, "c_next", (n_a, m)),
        (7, "ot", (n_a, m)),
        (0, "a_next", (n_a, m)),
        (8, "xt", (n_x, m)),
        (2, "a_prev", (n_a, m)),
        (3, "c_prev", (n_a, m)),
    ]
    for idx, label, shape in expected_cache_shapes:
        assert cache[idx].shape == shape, f"Wrong shape for cache[{idx}]({label}). {cache[idx].shape} != {shape}"

    assert a_next.shape == (n_a, m), f"Wrong shape for a_next. {a_next.shape} != {(n_a, m)}"
    assert c_next.shape == (n_a, m), f"Wrong shape for c_next. {c_next.shape} != {(n_a, m)}"
    assert y_pred.shape == (n_y, m), f"Wrong shape for y_pred. {y_pred.shape} != {(n_y, m)}"

    # (cache index, label, expected first two entries of row 0).
    expected_cache_values = [
        (4, "ft", [0.32969833, 0.0574555]),
        (5, "it", [0.0036446, 0.9806943]),
        (6, "cct", [0.99903873, 0.57509956]),
        (1, "c_next", [0.1352798, 0.39884899]),
        (7, "ot", [0.7477249, 0.71588751]),
        (0, "a_next", [0.10053951, 0.27129536]),
    ]
    for idx, label, values in expected_cache_values:
        assert np.allclose(cache[idx][0, 0:2], values), f"wrong values for {label}"

    assert np.allclose(y_pred[1], [0.417098, 0.449528, 0.223159, 0.278376,
                                   0.68453, 0.419221, 0.564025, 0.538475]), "Wrong values for y_pred"

    print("\033[92mAll tests passed")


def lstm_forward_test(target):
    """Check a full-sequence LSTM forward pass against reference values.

    Args:
        target: callable invoked as ``target(x, a0, parameters)`` that returns
            ``(a, y, c, caches)``.

    Raises:
        AssertionError: on the first wrong shape or value.
    """
    # Expected values depend on this seed and on the draw order below.
    np.random.seed(45)
    n_x = 4
    m = 13
    T_x = 16
    n_a = 3
    n_y = 2
    x_tmp = np.random.randn(n_x, m, T_x)
    a0_tmp = np.random.randn(n_a, m)
    parameters_tmp = {
        'Wf': np.random.randn(n_a, n_a + n_x),
        'bf': np.random.randn(n_a, 1),
        'Wi': np.random.randn(n_a, n_a + n_x),
        'bi': np.random.randn(n_a, 1),
        'Wo': np.random.randn(n_a, n_a + n_x),
        'bo': np.random.randn(n_a, 1),
        'Wc': np.random.randn(n_a, n_a + n_x),
        'bc': np.random.randn(n_a, 1),
        'Wy': np.random.randn(n_y, n_a),
        'by': np.random.randn(n_y, 1),
    }

    a, y, c, caches = target(x_tmp, a0_tmp, parameters_tmp)

    assert a.shape == (n_a, m, T_x), f"Wrong shape for a. {a.shape} != {(n_a, m, T_x)}"
    assert c.shape == (n_a, m, T_x), f"Wrong shape for c. {c.shape} != {(n_a, m, T_x)}"
    assert y.shape == (n_y, m, T_x), f"Wrong shape for y. {y.shape} != {(n_y, m, T_x)}"
    assert len(caches[0]) == T_x, f"Wrong shape for caches. {len(caches[0])} != {T_x} "
    assert len(caches[0][0]) == 10, "length of caches[0][0] must be 10."

    assert np.allclose(a[2, 1, 4:6], [-0.01606022, 0.0243569]), "Wrong values for a"
    assert np.allclose(c[2, 1, 4:6], [-0.02753855, 0.05668358]), "Wrong values for c"
    assert np.allclose(y[1, 1, 4:6], [0.70444592, 0.70648935]), "Wrong values for y"

    print("\033[92mAll tests passed")