# Source: blob/master/1_Supervised_Machine_Learning/Week 3. Classification/plt_quad_logistic.py
# (page metadata at capture time: 2826 views)
"""
plt_quad_logistic.py
    Interactive plot and supporting routines showing logistic regression.

Provides ``plt_quad_logistic``, a four-panel interactive figure (data + model,
contour of log cost, 3D cost surface, cost history), plus the helpers
``truncate_colormap`` and ``plt_prob`` used by the logistic gradient descent lab.
"""

import time
from matplotlib import cm
import matplotlib.colors as colors
from matplotlib.gridspec import GridSpec
from matplotlib.widgets import Button
from matplotlib.patches import FancyArrowPatch
from ipywidgets import Output
from lab_utils_common import np, plt, dlc, dlcolors, sigmoid, compute_cost_matrix, gradient_descent

# for debug
#output = Output() # sends hidden error messages to display when using widgets
#display(output)


class plt_quad_logistic:
    ''' plots a quad plot showing logistic regression '''
    # pylint: disable=too-many-instance-attributes
    # pylint: disable=too-many-locals
    # pylint: disable=missing-function-docstring
    # pylint: disable=attribute-defined-outside-init

    # Iteration counts for the successive gradient descent runs started by the button.
    # Each run continues from where the previous one stopped.
    _GD_ITERATIONS = (1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096)

    def __init__(self, x_train, y_train, w_range, b_range):
        """
        Build the 2x2 figure, its sub-plots and the interactive events.

        Args:
          x_train : training inputs; reshaped to a single column before cost/gradient calls
          y_train : training targets; points with y == 1 / y == 0 are drawn differently
          w_range : (min, max) of w for the contour and surface plots
          b_range : (min, max) of b for the contour and surface plots
        """
        # setup figure
        fig = plt.figure(figsize=(10, 6))
        fig.canvas.toolbar_visible = False
        fig.canvas.header_visible = False
        fig.canvas.footer_visible = False
        fig.set_facecolor('#ffffff')  # white
        gs = GridSpec(2, 2, figure=fig)
        ax0 = fig.add_subplot(gs[0, 0])
        ax1 = fig.add_subplot(gs[0, 1])
        ax2 = fig.add_subplot(gs[1, 0], projection='3d')
        ax3 = fig.add_subplot(gs[1, 1])

        # Place the button in the top-right corner of the cost-history subplot.
        # add_axes on our own figure (rather than plt.axes) does not depend on
        # which figure pyplot currently considers "current".
        pos = ax3.get_position().get_points()  # [[lb_x, lb_y], [rt_x, rt_y]]
        h = 0.05
        width = 0.2
        axcalc = fig.add_axes([pos[1, 0] - width, pos[1, 1] - h, width, h])  # lx, by, w, h

        self.fig = fig
        self.ax = np.array([ax0, ax1, ax2, ax3, axcalc])
        self.x_train = x_train
        self.y_train = y_train

        self.w = 0.  # initial point, non-array
        self.b = 0.

        # initialize subplots
        self.dplot = data_plot(self.ax[0], x_train, y_train, self.w, self.b)
        self.con_plot = contour_and_surface_plot(self.ax[1], self.ax[2], x_train, y_train,
                                                 w_range, b_range, self.w, self.b)
        self.cplot = cost_plot(self.ax[3])

        # setup events
        self.cid = fig.canvas.mpl_connect('button_press_event', self.click_contour)
        self.bcalc = Button(axcalc, 'Run Gradient Descent \nfrom current w,b (click)', color=dlc["dlorange"])
        self.bcalc.on_clicked(self.calc_logistic)

    # @output.capture() # debug
    def click_contour(self, event):
        ''' called when click in contour: restart all plots from the clicked (w, b) '''
        if event.inaxes == self.ax[1]:  # contour plot
            self.w = event.xdata
            self.b = event.ydata

            self.cplot.re_init()
            self.dplot.update(self.w, self.b)
            self.con_plot.update_contour_wb_lines(self.w, self.b)
            self.con_plot.path.re_init(self.w, self.b)

            self.fig.canvas.draw()

    # @output.capture() # debug
    def calc_logistic(self, event):
        ''' called on run gradient event: run gradient descent in stages, redrawing after each '''
        for it in self._GD_ITERATIONS:
            # 0.1 and `it` are presumably the learning rate and iteration count
            # expected by gradient_descent -- confirm against lab_utils_common.
            w, self.b, J_hist = gradient_descent(self.x_train.reshape(-1, 1), self.y_train.reshape(-1, 1),
                                                 np.array(self.w).reshape(-1, 1), self.b, 0.1, it,
                                                 logistic=True, lambda_=0, verbose=False)
            self.w = w[0, 0]
            self.dplot.update(self.w, self.b)
            self.con_plot.update_contour_wb_lines(self.w, self.b)
            self.con_plot.path.add_path_item(self.w, self.b)
            self.cplot.add_cost(J_hist)

            time.sleep(0.3)  # short pause so each stage is visible
            self.fig.canvas.draw()


class data_plot:
    ''' handles data plot: training points, current model curves and per-point costs '''
    # pylint: disable=missing-function-docstring
    # pylint: disable=attribute-defined-outside-init
    def __init__(self, ax, x_train, y_train, w, b):
        self.ax = ax
        self.x_train = x_train
        self.y_train = y_train
        self.m = x_train.shape[0]
        self.w = w
        self.b = b

        self.plt_tumor_data()
        self.draw_logistic_lines(firsttime=True)
        self.mk_cost_lines(firsttime=True)

        self.ax.autoscale(enable=False)  # leave plot scales the same after initial setup

    def plt_tumor_data(self):
        ''' scatter the training data: y == 1 as red crosses, y == 0 as open circles '''
        x = self.x_train
        y = self.y_train
        pos = y == 1
        neg = y == 0
        self.ax.scatter(x[pos], y[pos], marker='x', s=80, c='red', label="malignant")
        self.ax.scatter(x[neg], y[neg], marker='o', s=100, label="benign", facecolors='none',
                        edgecolors=dlc["dlblue"], lw=3)
        self.ax.set_ylim(-0.175, 1.1)
        self.ax.set_ylabel('y')
        self.ax.set_xlabel('Tumor Size')
        self.ax.set_title("Logistic Regression on Categorical Data")

    def update(self, w, b):
        ''' redraw model curves and cost lines for new parameters w, b '''
        self.w = w
        self.b = b
        self.draw_logistic_lines()
        self.mk_cost_lines()

    def draw_logistic_lines(self, firsttime=False):
        ''' draw z = w*x + b and sigmoid(z) across the current x limits, replacing old lines '''
        if not firsttime:
            self.aline[0].remove()
            self.bline[0].remove()
            self.alegend.remove()

        xlim = self.ax.get_xlim()
        x_hat = np.linspace(*xlim, 30)
        z_hat = np.dot(x_hat.reshape(-1, 1), self.w) + self.b  # linear part, shared by both curves
        y_hat = sigmoid(z_hat)
        self.aline = self.ax.plot(x_hat, y_hat, color=dlc["dlblue"],
                                  label="y = sigmoid(z)")
        self.bline = self.ax.plot(x_hat, z_hat, color=dlc["dlorange"], lw=1,
                                  label=f"z = {np.squeeze(self.w):0.2f}x+({self.b:0.2f})")
        self.alegend = self.ax.legend(loc='upper left')

    def mk_cost_lines(self, firsttime=False):
        ''' makes vertical cost lines and a text summary of the average cost '''
        if not firsttime:
            for artist in self.cost_items:
                artist.remove()
        self.cost_items = []
        cstr = f"cost = (1/{self.m})*("
        ctot = 0
        label = 'cost for point'
        addedbreak = False
        for xp, yp in zip(self.x_train, self.y_train):
            f_wb_p = sigmoid(self.w * xp + self.b)
            c_p = compute_cost_matrix(xp.reshape(-1, 1), yp, np.array(self.w), self.b,
                                      logistic=True, lambda_=0, safe=True)
            vline = self.ax.vlines(xp, yp, f_wb_p, lw=3, color=dlc["dlpurple"], ls='dotted', label=label)
            label = ''  # legend label only on the first line
            cxy = [xp, yp + (f_wb_p - yp) / 2]  # midpoint of the vertical line
            note = self.ax.annotate(f'{c_p:0.1f}', xy=cxy, xycoords='data', color=dlc["dlpurple"],
                                    xytext=(5, 0), textcoords='offset points')
            cstr += f"{c_p:0.1f} +"
            if len(cstr) > 38 and not addedbreak:
                cstr += "\n"  # wrap the summary once it gets long
                addedbreak = True
            ctot += c_p
            self.cost_items.extend((vline, note))
        ctot = ctot / self.m if self.m else 0.0  # avoid dividing by zero with no points
        # Drop the trailing "+". A line break may have been appended after it on the
        # last point, so strip that first; otherwise only the "\n" would be removed.
        cstr = cstr.rstrip("\n")
        if cstr.endswith("+"):
            cstr = cstr[:-1]
        cstr += f") = {ctot:0.2f}"
        # TODO: figure out how to get this textbox to extend to the width of the subplot
        ctext = self.ax.text(0.05, 0.02, cstr, transform=self.ax.transAxes, color=dlc["dlpurple"])
        self.cost_items.append(ctext)


class contour_and_surface_plot:
    ''' plots combined in class as they have similar operations '''
    # pylint: disable=missing-function-docstring
    # pylint: disable=attribute-defined-outside-init
    def __init__(self, axc, axs, x_train, y_train, w_range, b_range, w, b):
        """
        Draw the log-cost contour (axc) and the cost surface (axs) over w_range x b_range.

        Args:
          axc : 2D axes for the contour of log(cost)
          axs : 3D axes for the cost surface
          x_train, y_train : training data used to evaluate the cost
          w_range, b_range : (min, max) of the parameter grid
          w, b : initial parameters to mark on both plots
        """
        self.x_train = x_train
        self.y_train = y_train
        self.axc = axc
        self.axs = axs

        # setup useful ranges and common linspaces
        b_space = np.linspace(*b_range, 100)
        w_space = np.linspace(*w_range, 100)

        # get cost for w,b ranges for contour and 3D
        x_col = x_train.reshape(-1, 1)  # loop-invariant
        tmp_b, tmp_w = np.meshgrid(b_space, w_space)
        z = np.zeros_like(tmp_b)
        for i in range(tmp_w.shape[0]):
            for j in range(tmp_w.shape[1]):
                z[i, j] = compute_cost_matrix(x_col, y_train, tmp_w[i, j], tmp_b[i, j],
                                              logistic=True, lambda_=0, safe=True)
                if z[i, j] == 0:
                    z[i, j] = 1e-9  # keep np.log(z) finite for the contour

        ### plot contour ###
        axc.contour(tmp_w, tmp_b, np.log(z), levels=12, linewidths=2, alpha=0.7, colors=dlcolors)
        axc.set_title('log(Cost(w,b))')
        axc.set_xlabel('w', fontsize=10)
        axc.set_ylabel('b', fontsize=10)
        axc.set_xlim(w_range)
        axc.set_ylim(b_range)
        self.update_contour_wb_lines(w, b, firsttime=True)
        axc.text(0.7, 0.05, "Click to choose w,b", bbox=dict(facecolor='white', ec='black'), fontsize=10,
                 transform=axc.transAxes, verticalalignment='center', horizontalalignment='center')

        # Surface plot of the cost function J(w,b)
        axs.plot_surface(tmp_w, tmp_b, z, cmap=cm.jet, alpha=0.3, antialiased=True)
        axs.plot_wireframe(tmp_w, tmp_b, z, color='k', alpha=0.1)
        axs.set_xlabel("$w$")
        axs.set_ylabel("$b$")
        axs.zaxis.set_rotate_label(False)
        axs.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
        axs.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
        axs.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
        axs.set_zlabel("J(w, b)", rotation=90)
        axs.view_init(30, -120)

        axs.autoscale(enable=False)
        axc.autoscale(enable=False)

        self.path = path(self.w, self.b, self.axc)  # initialize an empty path, avoids existence check

    def update_contour_wb_lines(self, w, b, firsttime=False):
        ''' mark (w, b) and its cost on the contour and 3D plots, replacing the previous marks '''
        self.w = w
        self.b = b
        cst = compute_cost_matrix(self.x_train.reshape(-1, 1), self.y_train, np.array(self.w), self.b,
                                  logistic=True, lambda_=0, safe=True)

        # remove lines and re-add on contour plot and 3d plot
        if not firsttime:
            for artist in self.dyn_items:
                artist.remove()
        point = self.axc.scatter(self.w, self.b, s=100, color=dlc["dlblue"], zorder=10,
                                 label="cost with \ncurrent w,b")
        hline = self.axc.hlines(self.b, self.axc.get_xlim()[0], self.w, lw=4, color=dlc["dlpurple"], ls='dotted')
        vline = self.axc.vlines(self.w, self.axc.get_ylim()[0], self.b, lw=4, color=dlc["dlpurple"], ls='dotted')
        note = self.axc.annotate(f"Cost: {cst:0.2f}", xy=(self.w, self.b), xytext=(4, 4),
                                 textcoords='offset points', bbox=dict(facecolor='white'), size=10)
        # Add point in 3D surface plot
        point3d = self.axs.scatter3D(self.w, self.b, cst, marker='X', s=100)

        self.dyn_items = [point, hline, vline, note, point3d]


class cost_plot:
    """ manages cost plot for plt_quad_logistic """
    # pylint: disable=missing-function-docstring
    # pylint: disable=attribute-defined-outside-init
    def __init__(self, ax):
        self.ax = ax
        self._setup()

    def _setup(self):
        ''' label the axes and start an empty cost history '''
        # NOTE(review): the axis is labelled log(cost) but add_cost plots the raw
        # J_hist values on a linear axis -- confirm which is intended.
        self.ax.set_ylabel("log(cost)")
        self.ax.set_xlabel("iteration")
        self.costs = []
        self.cline = self.ax.plot(0, 0, color=dlc["dlblue"])

    def re_init(self):
        ''' clear the axes and restart the cost history '''
        self.ax.clear()
        self._setup()

    def add_cost(self, J_hist):
        ''' append J_hist to the history and redraw the whole curve '''
        self.costs.extend(J_hist)
        self.cline[0].remove()
        self.cline = self.ax.plot(self.costs, color=dlc["dlblue"])  # keep the initial line color


class path:
    ''' tracks paths during gradient descent on contour plot '''
    # pylint: disable=missing-function-docstring
    # pylint: disable=attribute-defined-outside-init
    def __init__(self, w, b, ax):
        ''' w, b at start of path '''
        self.path_items = []
        self.w = w
        self.b = b
        self.ax = ax

    def re_init(self, w, b):
        ''' remove all drawn arrows and restart the path at (w, b) '''
        for artist in self.path_items:
            artist.remove()
        self.path_items = []
        self.w = w
        self.b = b

    def add_path_item(self, w, b):
        ''' draw an arrow from the previous point to (w, b) and make (w, b) the new end '''
        arrow = FancyArrowPatch(
            posA=(self.w, self.b), posB=(w, b), color=dlc["dlblue"],
            arrowstyle='simple, head_width=5, head_length=10, tail_width=0.0',
        )
        self.ax.add_artist(arrow)
        self.path_items.append(arrow)
        self.w = w
        self.b = b

#-----------
# related to the logistic gradient descent lab
#----------

def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """ truncates color map: resample cmap between minval and maxval into n colors """
    new_cmap = colors.LinearSegmentedColormap.from_list(
        f'trunc({cmap.name},{minval:.2f},{maxval:.2f})',
        cmap(np.linspace(minval, maxval, n)))
    return new_cmap


def plt_prob(ax, w_out, b_out):
    """
    Plot a decision boundary, with shading to indicate the probability.

    Shades sigmoid(w_out[0]*x0 + w_out[1]*x1 + b_out) over the square [0, 4] x [0, 4]
    and adds a colorbar.

    Args:
      ax    : axes to draw on
      w_out : two weights (any shape holding exactly two values, e.g. (2,))
      b_out : bias
    """
    # setup useful ranges and common linspaces
    x0_space = np.linspace(0, 4, 100)
    x1_space = np.linspace(0, 4, 100)

    # get probability for x0,x1 ranges (vectorized over the whole grid)
    tmp_x0, tmp_x1 = np.meshgrid(x0_space, x1_space)
    w = np.ravel(w_out)
    z = sigmoid(w[0] * tmp_x0 + w[1] * tmp_x1 + b_out)

    cmap = plt.get_cmap('Blues')
    new_cmap = truncate_colormap(cmap, 0.0, 0.5)
    pcm = ax.pcolormesh(tmp_x0, tmp_x1, z,
                        norm=colors.Normalize(vmin=0, vmax=1),
                        cmap=new_cmap, shading='nearest', alpha=0.9)
    ax.figure.colorbar(pcm, ax=ax)