# Source: C1 - Supervised Machine Learning - Regression and Classification/week2/C1W2A1/public_tests.py
import numpy as np


def compute_cost_test(target):
    """Check a linear-regression cost function against hand-computed values.

    ``target(x, y, w, b)`` is called with 1-D NumPy arrays ``x`` and ``y`` and
    scalar parameters ``w`` and ``b``. The expected values below are those of
    the squared-error cost ``sum((w * x + b - y) ** 2) / (2 * m)``, where
    ``m`` is the number of samples.

    Args:
        target: The cost function under test.

    Raises:
        AssertionError: If any case returns an unexpected cost; the message
            names the failing case.
    """
    # print("Using X with shape (4, 1)")
    # Case 1: y == 2*x + 3 exactly, so the cost must be zero.
    x = np.array([2, 4, 6, 8])
    y = np.array([7, 11, 15, 19])
    initial_w = 2
    initial_b = 3.0
    cost = target(x, y, initial_w, initial_b)
    assert np.isclose(cost, 0), f"Case 1: Cost must be 0 for a perfect prediction but got {cost}"

    # Case 2: every prediction is 2 below its target -> 4 * 4 / (2 * 4) = 2.
    x = np.array([2, 4, 6, 8])
    y = np.array([7, 11, 15, 19])
    initial_w = 2.0
    initial_b = 1.0
    cost = target(x, y, initial_w, initial_b)
    assert np.isclose(cost, 2), f"Case 2: Cost must be 2 but got {cost}"

    # print("Using X with shape (5, 1)")
    # Case 3: five samples, imperfect fit.
    x = np.array([1.5, 2.5, 3.5, 4.5, 1.5])
    y = np.array([4, 7, 10, 13, 5])
    initial_w = 1
    initial_b = 0.0
    cost = target(x, y, initial_w, initial_b)
    assert np.isclose(cost, 15.325), f"Case 3: Cost must be 15.325 but got {cost}"

    # Case 4: same data, shifted intercept.
    initial_b = 1.0
    cost = target(x, y, initial_w, initial_b)
    assert np.isclose(cost, 10.725), f"Case 4: Cost must be 10.725 but got {cost}"

    # Case 5: shift the targets down by 2 (builds a new array; x is reused).
    y = y - 2
    initial_b = 1.0
    cost = target(x, y, initial_w, initial_b)
    assert np.isclose(cost, 4.525), f"Case 5: Cost must be 4.525 but got {cost}"

    print("\033[92mAll tests passed!")


def compute_gradient_test(target):
    """Check a linear-regression gradient function against hand-computed values.

    ``target(x, y, w, b)`` is called with 1-D NumPy arrays ``x`` and ``y`` and
    scalar parameters ``w`` and ``b``, and must return ``(dj_dw, dj_db)``.
    The expected values are those of ``mean((w * x + b - y) * x)`` and
    ``mean(w * x + b - y)``.

    Args:
        target: The gradient function under test.

    Raises:
        AssertionError: If any case returns an unexpected gradient; the
            message names the failing case and component.
    """
    print("Using X with shape (4, 1)")
    # Case 1: y == 2*x + 0.5 exactly, so both gradient components vanish.
    x = np.array([2, 4, 6, 8])
    y = np.array([4.5, 8.5, 12.5, 16.5])
    initial_w = 2.
    initial_b = 0.5
    dj_dw, dj_db = target(x, y, initial_w, initial_b)
    assert np.isclose(dj_db, 0.0), f"Case 1: dj_db is wrong: {dj_db} != 0.0"
    assert np.allclose(dj_dw, 0), f"Case 1: dj_dw is wrong: {dj_dw} != 0.0"

    # Case 2: every prediction is 2 below its target (error == -2), so
    # dj_db = -2 and dj_dw = -2 * mean(x) = -2 * 5 = -10.
    x = np.array([2, 4, 6, 8])
    y = np.array([4, 7, 10, 13]) + 2
    initial_w = 1.5
    initial_b = 1
    dj_dw, dj_db = target(x, y, initial_w, initial_b)
    assert np.isclose(dj_db, -2), f"Case 2: dj_db is wrong: {dj_db} != -2"
    assert np.allclose(dj_dw, -10.0), f"Case 2: dj_dw is wrong: {dj_dw} != -10.0"

    print("\033[92mAll tests passed!")