CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
amanchadha

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: amanchadha/coursera-deep-learning-specialization
Path: blob/master/C5 - Sequence Models/Week 1/Building a Recurrent Neural Network - Step by Step/public_tests.py
Views: 4819
1
import numpy as np
2
from rnn_utils import *
3
4
def rnn_cell_forward_tests(target):
    """Unit-test an ``rnn_cell_forward`` implementation one term at a time.

    Three cases are run. Each one zeroes every input except one, so that a
    failure points at a single term of the update:

    1. only the biases ``ba`` / ``by`` contribute;
    2. only ``xt`` contributes to ``a_next``;
    3. only ``a_prev`` contributes to ``a_next``.

    Expected values are recomputed in closed form with ``np.tanh`` and
    ``softmax`` (imported from ``rnn_utils``), so no seed is needed.

    Args:
        target: callable ``(xt, a_prev, parameters) -> (a_next, yt_pred, cache)``
            whose ``cache`` holds ``(a_next, a_prev, xt, parameters)``.

    Raises:
        AssertionError: on the first shape or value mismatch.
    """
    # Case 1: only the biases are in the expression.
    a_prev_tmp = np.zeros((5, 10))
    xt_tmp = np.zeros((3, 10))
    parameters_tmp = {}
    parameters_tmp['Waa'] = np.random.randn(5, 5)
    parameters_tmp['Wax'] = np.random.randn(5, 3)
    # Zero Wya so that yt_pred reduces to softmax(by) in this case.
    parameters_tmp['Wya'] = np.zeros((2, 5))
    parameters_tmp['ba'] = np.random.randn(5, 1)
    parameters_tmp['by'] = np.random.randn(2, 1)

    a_next_tmp, yt_pred_tmp, cache_tmp = target(xt_tmp, a_prev_tmp, parameters_tmp)

    assert a_next_tmp.shape == (5, 10), f"Wrong shape for a_next. Expected (5, 10) != {a_next_tmp.shape}"
    assert yt_pred_tmp.shape == (2, 10), f"Wrong shape for yt_pred. Expected (2, 10) != {yt_pred_tmp.shape}"
    assert cache_tmp[0].shape == (5, 10), "Wrong shape in cache->a_next"
    assert cache_tmp[1].shape == (5, 10), "Wrong shape in cache->a_prev"
    assert cache_tmp[2].shape == (3, 10), "Wrong shape in cache->x_t"
    assert len(cache_tmp[3].keys()) == 5, "Wrong number of parameters in cache. Expected 5"

    # With a_prev and xt both zero, a_next = tanh(ba) broadcast over the batch.
    assert np.allclose(np.tanh(parameters_tmp['ba']), a_next_tmp), "Problem 1 in a_next expression. Related to ba?"
    assert np.allclose(softmax(parameters_tmp['by']), yt_pred_tmp), "Problem 1 in yt_pred expression. Related to by?"

    # Case 2: only xt is in the expression.
    a_prev_tmp = np.zeros((5, 10))
    xt_tmp = np.random.randn(3, 10)
    parameters_tmp['Wax'] = np.random.randn(5, 3)
    parameters_tmp['ba'] = np.zeros((5, 1))
    parameters_tmp['by'] = np.zeros((2, 1))
    # Wya must be non-zero from here on. With the all-zero Wya from case 1,
    # yt_pred would be softmax(0) no matter what a_next is, and the yt_pred
    # checks below could never fail. It stays non-zero for case 3 as well.
    parameters_tmp['Wya'] = np.random.randn(2, 5)

    a_next_tmp, yt_pred_tmp, cache_tmp = target(xt_tmp, a_prev_tmp, parameters_tmp)

    assert np.allclose(np.tanh(np.dot(parameters_tmp['Wax'], xt_tmp)), a_next_tmp), "Problem 2 in a_next expression. Related to xt?"
    assert np.allclose(softmax(np.dot(parameters_tmp['Wya'], a_next_tmp)), yt_pred_tmp), "Problem 2 in yt_pred expression. Related to a_next?"

    # Case 3: only a_prev is in the expression.
    a_prev_tmp = np.random.randn(5, 10)
    xt_tmp = np.zeros((3, 10))
    parameters_tmp['Waa'] = np.random.randn(5, 5)
    parameters_tmp['ba'] = np.zeros((5, 1))
    parameters_tmp['by'] = np.zeros((2, 1))

    a_next_tmp, yt_pred_tmp, cache_tmp = target(xt_tmp, a_prev_tmp, parameters_tmp)

    assert np.allclose(np.tanh(np.dot(parameters_tmp['Waa'], a_prev_tmp)), a_next_tmp), "Problem 3 in a_next expression. Related to a_prev?"
    assert np.allclose(softmax(np.dot(parameters_tmp['Wya'], a_next_tmp)), yt_pred_tmp), "Problem 3 in yt_pred expression. Related to a_next?"

    print("\033[92mAll tests passed")
53
54
55
def rnn_forward_test(target):
    """Check an ``rnn_forward`` implementation against stored reference values.

    Seeds the global NumPy RNG with 17 and draws the input sequence, the
    initial hidden state and the parameters in a fixed order. The spot-checked
    values of ``a`` and ``y_pred`` below depend on that exact order.

    Args:
        target: callable ``(x, a0, parameters) -> (a, y_pred, caches)`` where
            ``caches[0]`` has one entry per time step and ``caches[1]`` is ``x``.

    Raises:
        AssertionError: on the first shape, length or value mismatch.
    """
    np.random.seed(17)
    T_x = 13
    m = 8
    n_x = 4
    n_a = 7
    n_y = 3
    x_tmp = np.random.randn(n_x, m, T_x)
    a0_tmp = np.random.randn(n_a, m)
    parameters_tmp = {}
    parameters_tmp['Waa'] = np.random.randn(n_a, n_a)
    parameters_tmp['Wax'] = np.random.randn(n_a, n_x)
    parameters_tmp['Wya'] = np.random.randn(n_y, n_a)
    parameters_tmp['ba'] = np.random.randn(n_a, 1)
    parameters_tmp['by'] = np.random.randn(n_y, 1)

    a, y_pred, caches = target(x_tmp, a0_tmp, parameters_tmp)

    # Interpolate the shape tuples directly so the message reads
    # "Expected: (7, 8, 13)" rather than "Expected: ((7, 8, 13))".
    expected_a_shape = (n_a, m, T_x)
    expected_y_shape = (n_y, m, T_x)
    assert a.shape == expected_a_shape, f"Wrong shape for a. Expected: {expected_a_shape} != {a.shape}"
    assert y_pred.shape == expected_y_shape, f"Wrong shape for y_pred. Expected: {expected_y_shape} != {y_pred.shape}"
    assert len(caches[0]) == T_x, f"len(caches[0]) must be T_x = {T_x}"

    assert np.allclose(a[5, 2, 2:6], [0.99999291, 0.99332189, 0.9921928, 0.99503445]), "Wrong values for a"
    assert np.allclose(y_pred[2, 1, 1: 5], [0.19428, 0.14292, 0.24993, 0.00119], atol=1e-4), "Wrong values for y_pred"
    assert np.allclose(caches[1], x_tmp), "Fail check: caches[1] != x_tmp"

    print("\033[92mAll tests passed")
83
84
def lstm_cell_forward_test(target):
    """Check an ``lstm_cell_forward`` implementation against stored reference values.

    Seeds the global NumPy RNG with 212 and draws the inputs and parameters in
    a fixed order. The spot-checked values below depend on that exact order.

    Args:
        target: callable ``(xt, a_prev, c_prev, parameters) ->
            (a_next, c_next, y_pred, cache)``. ``cache`` is a 10-tuple:
            0 a_next, 1 c_next, 2 a_prev, 3 c_prev, 4 ft, 5 it, 6 cct, 7 ot,
            8 xt, 9 presumably the parameters dict (not inspected here).

    Raises:
        AssertionError: on the first length, shape or value mismatch.
    """
    np.random.seed(212)
    m = 8
    n_x = 4
    n_a = 7
    n_y = 3
    x = np.random.randn(n_x, m)
    a0 = np.random.randn(n_a, m)
    c0 = np.random.randn(n_a, m)
    params = {}
    # Gate weights have n_a + n_x columns (presumably applied to a_prev
    # stacked with xt). Do not reorder these draws: the reference numbers
    # below depend on them.
    params['Wf'] = np.random.randn(n_a, n_a + n_x)
    params['bf'] = np.random.randn(n_a, 1)
    params['Wi'] = np.random.randn(n_a, n_a + n_x)
    params['bi'] = np.random.randn(n_a, 1)
    params['Wo'] = np.random.randn(n_a, n_a + n_x)
    params['bo'] = np.random.randn(n_a, 1)
    params['Wc'] = np.random.randn(n_a, n_a + n_x)
    params['bc'] = np.random.randn(n_a, 1)
    params['Wy'] = np.random.randn(n_y, n_a)
    params['by'] = np.random.randn(n_y, 1)
    a_next, c_next, y_pred, cache = target(x, a0, c0, params)

    assert len(cache) == 10, "Don't change the cache"

    assert cache[4].shape == (n_a, m), f"Wrong shape for cache[4](ft). {cache[4].shape} != {(n_a, m)}"
    assert cache[5].shape == (n_a, m), f"Wrong shape for cache[5](it). {cache[5].shape} != {(n_a, m)}"
    assert cache[6].shape == (n_a, m), f"Wrong shape for cache[6](cct). {cache[6].shape} != {(n_a, m)}"
    assert cache[1].shape == (n_a, m), f"Wrong shape for cache[1](c_next). {cache[1].shape} != {(n_a, m)}"
    assert cache[7].shape == (n_a, m), f"Wrong shape for cache[7](ot). {cache[7].shape} != {(n_a, m)}"
    assert cache[0].shape == (n_a, m), f"Wrong shape for cache[0](a_next). {cache[0].shape} != {(n_a, m)}"
    assert cache[8].shape == (n_x, m), f"Wrong shape for cache[8](xt). {cache[8].shape} != {(n_x, m)}"
    assert cache[2].shape == (n_a, m), f"Wrong shape for cache[2](a_prev). {cache[2].shape} != {(n_a, m)}"
    assert cache[3].shape == (n_a, m), f"Wrong shape for cache[3](c_prev). {cache[3].shape} != {(n_a, m)}"

    assert a_next.shape == (n_a, m), f"Wrong shape for a_next. {a_next.shape} != {(n_a, m)}"
    assert c_next.shape == (n_a, m), f"Wrong shape for c_next. {c_next.shape} != {(n_a, m)}"
    assert y_pred.shape == (n_y, m), f"Wrong shape for y_pred. {y_pred.shape} != {(n_y, m)}"

    assert np.allclose(cache[4][0, 0:2], [0.32969833, 0.0574555]), "wrong values for ft"
    assert np.allclose(cache[5][0, 0:2], [0.0036446, 0.9806943]), "wrong values for it"
    assert np.allclose(cache[6][0, 0:2], [0.99903873, 0.57509956]), "wrong values for cct"
    assert np.allclose(cache[1][0, 0:2], [0.1352798, 0.39884899]), "wrong values for c_next"
    assert np.allclose(cache[7][0, 0:2], [0.7477249, 0.71588751]), "wrong values for ot"
    assert np.allclose(cache[0][0, 0:2], [0.10053951, 0.27129536]), "wrong values for a_next"

    # The returned a_next / c_next were otherwise only shape-checked. Both are
    # (n_a, m), so returning them in swapped order would go unnoticed. Tie them
    # to the cache entries whose values are pinned above.
    assert np.allclose(a_next, cache[0]), "Returned a_next does not match cache[0](a_next)"
    assert np.allclose(c_next, cache[1]), "Returned c_next does not match cache[1](c_next)"

    assert np.allclose(y_pred[1], [0.417098, 0.449528, 0.223159, 0.278376,
                                   0.68453, 0.419221, 0.564025, 0.538475]), "Wrong values for y_pred"

    print("\033[92mAll tests passed")
134
135
def lstm_forward_test(target):
    """Check an ``lstm_forward`` implementation against stored reference values.

    Seeds the global NumPy RNG with 45 and draws the input sequence, the
    initial hidden state and the parameters in a fixed order. The
    spot-checked entries of ``a``, ``c`` and ``y`` depend on that exact order.

    Args:
        target: callable ``(x, a0, parameters) -> (a, y, c, caches)`` where
            ``caches[0]`` holds one 10-element cache per time step.

    Raises:
        AssertionError: on the first shape, length or value mismatch.
    """
    np.random.seed(45)
    n_x, m, T_x = 4, 13, 16
    n_a, n_y = 3, 2

    x_tmp = np.random.randn(n_x, m, T_x)
    a0_tmp = np.random.randn(n_a, m)

    # Draw weights then bias for each gate in the order f, i, o, c, then the
    # output layer. This is the same sequence of RNG draws (and dict insertion
    # order) the reference values were produced with.
    parameters_tmp = {}
    for gate in ("f", "i", "o", "c"):
        parameters_tmp["W" + gate] = np.random.randn(n_a, n_a + n_x)
        parameters_tmp["b" + gate] = np.random.randn(n_a, 1)
    parameters_tmp["Wy"] = np.random.randn(n_y, n_a)
    parameters_tmp["by"] = np.random.randn(n_y, 1)

    a, y, c, caches = target(x_tmp, a0_tmp, parameters_tmp)

    hidden_shape = (n_a, m, T_x)
    shape_checks = (
        ("a", a, hidden_shape),
        ("c", c, hidden_shape),
        ("y", y, (n_y, m, T_x)),
    )
    for label, array, wanted in shape_checks:
        assert array.shape == wanted, f"Wrong shape for {label}. {array.shape} != {wanted}"

    per_step = caches[0]
    assert len(per_step) == T_x, f"Wrong shape for caches. {len(per_step)} != {T_x} "
    assert len(per_step[0]) == 10, "length of caches[0][0] must be 10."

    value_checks = (
        (a[2, 1, 4:6], [-0.01606022, 0.0243569], "Wrong values for a"),
        (c[2, 1, 4:6], [-0.02753855, 0.05668358], "Wrong values for c"),
        (y[1, 1, 4:6], [0.70444592, 0.70648935], "Wrong values for y"),
    )
    for observed, reference, failure in value_checks:
        assert np.allclose(observed, reference), failure

    print("\033[92mAll tests passed")
169