Path: blob/master/1_Supervised_Machine_Learning/Week 3. Classification/plt_logistic_loss.py
"""----------------------------------------------------------------1logistic_loss plotting routines and support2"""34from matplotlib import cm5from lab_utils_common import sigmoid, dlblue, dlorange, np, plt, compute_cost_matrix67def compute_cost_logistic_sq_err(X, y, w, b):8"""9compute sq error cost on logicist data (for negative example only, not used in practice)10Args:11X (ndarray): Shape (m,n) matrix of examples with multiple features12w (ndarray): Shape (n) parameters for prediction13b (scalar): parameter for prediction14Returns:15cost (scalar): cost16"""17m = X.shape[0]18cost = 0.019for i in range(m):20z_i = np.dot(X[i],w) + b21f_wb_i = sigmoid(z_i) #add sigmoid to normal sq error cost for linear regression22cost = cost + (f_wb_i - y[i])**223cost = cost / (2 * m)24return np.squeeze(cost)2526def plt_logistic_squared_error(X,y):27""" plots logistic squared error for demonstration """28wx, by = np.meshgrid(np.linspace(-6,12,50),29np.linspace(10, -20, 40))30points = np.c_[wx.ravel(), by.ravel()]31cost = np.zeros(points.shape[0])3233for i in range(points.shape[0]):34w,b = points[i]35cost[i] = compute_cost_logistic_sq_err(X.reshape(-1,1), y, w, b)36cost = cost.reshape(wx.shape)3738fig = plt.figure()39fig.canvas.toolbar_visible = False40fig.canvas.header_visible = False41fig.canvas.footer_visible = False42ax = fig.add_subplot(1, 1, 1, projection='3d')43ax.plot_surface(wx, by, cost, alpha=0.6,cmap=cm.jet,)4445ax.set_xlabel('w', fontsize=16)46ax.set_ylabel('b', fontsize=16)47ax.set_zlabel("Cost", rotation=90, fontsize=16)48ax.set_title('"Logistic" Squared Error Cost vs (w, b)')49ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))50ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))51ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))525354def plt_logistic_cost(X,y):55""" plots logistic cost """56wx, by = np.meshgrid(np.linspace(-6,12,50),57np.linspace(0, -20, 40))58points = np.c_[wx.ravel(), by.ravel()]59cost = np.zeros(points.shape[0],dtype=np.longdouble)6061for i in range(points.shape[0]):62w,b = points[i]63cost[i] = compute_cost_matrix(X.reshape(-1,1), y, w, b, logistic=True, safe=True)64cost = cost.reshape(wx.shape)6566fig = plt.figure(figsize=(9,5))67fig.canvas.toolbar_visible = False68fig.canvas.header_visible = False69fig.canvas.footer_visible = False70ax = fig.add_subplot(1, 2, 1, projection='3d')71ax.plot_surface(wx, by, cost, alpha=0.6,cmap=cm.jet,)7273ax.set_xlabel('w', fontsize=16)74ax.set_ylabel('b', fontsize=16)75ax.set_zlabel("Cost", rotation=90, fontsize=16)76ax.set_title('Logistic Cost vs (w, b)')77ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))78ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))79ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))8081ax = fig.add_subplot(1, 2, 2, projection='3d')8283ax.plot_surface(wx, by, np.log(cost), alpha=0.6,cmap=cm.jet,)8485ax.set_xlabel('w', fontsize=16)86ax.set_ylabel('b', fontsize=16)87ax.set_zlabel('\nlog(Cost)', fontsize=16)88ax.set_title('log(Logistic Cost) vs (w, b)')89ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))90ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))91ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))9293plt.show()94return cost959697def soup_bowl():98""" creates 3D quadratic error surface """99#Create figure and plot with a 3D projection100fig = plt.figure(figsize=(4,4))101fig.canvas.toolbar_visible = False102fig.canvas.header_visible = False103fig.canvas.footer_visible = False104105#Plot configuration106ax = fig.add_subplot(111, projection='3d')107ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))108ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 
0.0))109ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))110ax.zaxis.set_rotate_label(False)111ax.view_init(15, -120)112113#Useful linearspaces to give values to the parameters w and b114w = np.linspace(-20, 20, 100)115b = np.linspace(-20, 20, 100)116117#Get the z value for a bowl-shaped cost function118z=np.zeros((len(w), len(b)))119j=0120for x in w:121i=0122for y in b:123z[i,j] = x**2 + y**2124i+=1125j+=1126127#Meshgrid used for plotting 3D functions128W, B = np.meshgrid(w, b)129130#Create the 3D surface plot of the bowl-shaped cost function131ax.plot_surface(W, B, z, cmap = "Spectral_r", alpha=0.7, antialiased=False)132ax.plot_wireframe(W, B, z, color='k', alpha=0.1)133ax.set_xlabel("$w$")134ax.set_ylabel("$b$")135ax.set_zlabel("Cost", rotation=90)136ax.set_title("Squared Error Cost used in Linear Regression")137138plt.show()139140141def plt_simple_example(x, y):142""" plots tumor data """143pos = y == 1144neg = y == 0145146fig,ax = plt.subplots(1,1,figsize=(5,3))147fig.canvas.toolbar_visible = False148fig.canvas.header_visible = False149fig.canvas.footer_visible = False150151ax.scatter(x[pos], y[pos], marker='x', s=80, c = 'red', label="malignant")152ax.scatter(x[neg], y[neg], marker='o', s=100, label="benign", facecolors='none', edgecolors=dlblue,lw=3)153ax.set_ylim(-0.075,1.1)154ax.set_ylabel('y')155ax.set_xlabel('Tumor Size')156ax.legend(loc='lower right')157ax.set_title("Example of Logistic Regression on Categorical Data")158159160def plt_two_logistic_loss_curves():161""" plots the logistic loss """162fig,ax = plt.subplots(1,2,figsize=(6,3),sharey=True)163fig.canvas.toolbar_visible = False164fig.canvas.header_visible = False165fig.canvas.footer_visible = False166x = np.linspace(0.01,1-0.01,20)167ax[0].plot(x,-np.log(x))168#ax[0].set_title("y = 1")169ax[0].text(0.5, 4.0, "y = 1", fontsize=12)170ax[0].set_ylabel("loss")171ax[0].set_xlabel(r"$f_{w,b}(x)$")172ax[1].plot(x,-np.log(1-x))173#ax[1].set_title("y = 0")174ax[1].text(0.5, 4.0, "y = 0", fontsize=12)175ax[1].set_xlabel(r"$f_{w,b}(x)$")176ax[0].annotate("prediction \nmatches \ntarget ", xy= [1,0], xycoords='data',177xytext=[-10,30],textcoords='offset points', ha="right", va="center",178arrowprops={'arrowstyle': '->', 'color': dlorange, 'lw': 3},)179ax[0].annotate("loss increases as prediction\n differs from target", xy= [0.1,-np.log(0.1)], xycoords='data',180xytext=[10,30],textcoords='offset points', ha="left", va="center",181arrowprops={'arrowstyle': '->', 'color': dlorange, 'lw': 3},)182ax[1].annotate("prediction \nmatches \ntarget ", xy= [0,0], xycoords='data',183xytext=[10,30],textcoords='offset points', ha="left", va="center",184arrowprops={'arrowstyle': '->', 'color': dlorange, 'lw': 3},)185ax[1].annotate("loss increases as prediction\n differs from target", xy= [0.9,-np.log(1-0.9)], xycoords='data',186xytext=[-10,30],textcoords='offset points', ha="right", va="center",187arrowprops={'arrowstyle': '->', 'color': dlorange, 'lw': 3},)188plt.suptitle("Loss Curves for Two Categorical Target Values", fontsize=12)189plt.tight_layout()190plt.show()191192193
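

# ----------------------------------------------------------------
# Usage sketch (not part of the original lab file): a minimal driver showing
# how these plotting routines might be invoked from a notebook or script.
# The one-feature tumor dataset below is illustrative only, and the
# lab_utils_common import above is assumed to resolve in this environment.
if __name__ == "__main__":
    x_train = np.array([0., 1., 2., 3., 4., 5.])   # hypothetical tumor sizes
    y_train = np.array([0, 0, 0, 1, 1, 1])         # hypothetical labels: 0 = benign, 1 = malignant

    plt_simple_example(x_train, y_train)            # scatter of the categorical data
    plt_two_logistic_loss_curves()                  # -log(f) and -log(1-f) loss curves
    soup_bowl()                                     # convex squared-error bowl from linear regression
    plt_logistic_squared_error(x_train, y_train)    # non-convex surface when squared error meets the sigmoid
    plt_logistic_cost(x_train, y_train)             # logistic (cross-entropy) cost surface and its log
    plt.show()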