# Path: blob/master/1_Supervised_Machine_Learning/Week 3. Classification/plt_one_addpt_onclick.py
# (2826 views)
import time
import copy
from ipywidgets import Output
from matplotlib.widgets import Button, CheckButtons
from matplotlib.patches import FancyArrowPatch
from lab_utils_common import np, plt, dlblue, dlorange, sigmoid, dldarkred, gradient_descent

# for debug
#output = Output() # sends hidden error messages to display when using widgets
#display(output)


class plt_one_addpt_onclick:
    """Run one interactive plot of a one-feature, two-class data set.

    The figure shows the training points (``y == 1`` as red crosses labelled
    "malignant", ``y == 0`` as hollow blue circles labelled "benign") and the
    current model curve. Below the plot are a button that runs gradient descent
    (logistic or linear, depending on ``logistic``) and a check box that
    toggles a 0.5-threshold overlay. Clicking inside the main axes adds a data
    point: clicks above y = 0.5 add a y = 1 point, all others a y = 0 point.

    Args:
        x: numpy array of feature values ("Tumor Size"); assumed 1-D.
        y: numpy array of 0/1 labels aligned with ``x``.
        w: initial weight(s); deep-copied so training does not modify the
           caller's object.
        b: initial bias.
        logistic: True for logistic regression, False for linear regression.
    """

    def __init__(self, x, y, w, b, logistic=True):
        self.logistic = logistic
        pos = y == 1
        neg = y == 0

        fig, ax = plt.subplots(1, 1, figsize=(8, 4))
        fig.canvas.toolbar_visible = False
        fig.canvas.header_visible = False
        fig.canvas.footer_visible = False

        plt.subplots_adjust(bottom=0.25)  # leave room for the widgets below the plot
        ax.scatter(x[pos], y[pos], marker='x', s=80, c='red', label="malignant")
        ax.scatter(x[neg], y[neg], marker='o', s=100, label="benign",
                   facecolors='none', edgecolors=dlblue, lw=3)
        ax.set_ylim(-0.05, 1.1)
        xlim = ax.get_xlim()
        # Double the upper x-limit; presumably this leaves empty space to the
        # right for the user to click in new points.
        ax.set_xlim(xlim[0], xlim[1] * 2)
        ax.set_ylabel('y')
        ax.set_xlabel('Tumor Size')
        self.alegend = ax.legend(loc='lower right')
        if self.logistic:
            ax.set_title("Example of Logistic Regression on Categorical Data")
        else:
            ax.set_title("Example of Linear Regression on Categorical Data")

        ax.text(0.65, 0.8, "[Click to add data points]", size=10, transform=ax.transAxes)

        axcalc = plt.axes([0.1, 0.05, 0.38, 0.075])    # l,b,w,h
        axthresh = plt.axes([0.5, 0.05, 0.38, 0.075])  # l,b,w,h
        self.tlist = []  # artists created by draw_thresh, deleted by remove_thresh

        self.fig = fig
        self.ax = [ax, axcalc, axthresh]
        self.x = x
        self.y = y
        self.w = copy.deepcopy(w)
        self.b = b

        # Initial model curve, drawn in x-sorted order so a non-zero w with
        # unsorted x does not produce a zig-zag of chords.
        x_flat = np.ravel(self.x)
        order = np.argsort(x_flat)
        x_sorted = x_flat[order]
        f_wb = (np.matmul(x_flat.reshape(-1, 1), self.w) + self.b)[order]
        if self.logistic:
            self.aline = self.ax[0].plot(x_sorted, sigmoid(f_wb), color=dlblue)
            self.bline = self.ax[0].plot(x_sorted, f_wb, color=dlorange, lw=1)
        else:
            # The linear model's prediction is f_wb itself (no sigmoid),
            # the same quantity calc_linear plots.
            self.aline = self.ax[0].plot(x_sorted, f_wb, color=dlblue)

        self.cid = fig.canvas.mpl_connect('button_press_event', self.add_data)
        if self.logistic:
            self.bcalc = Button(axcalc, 'Run Logistic Regression (click)', color=dlblue)
            self.bcalc.on_clicked(self.calc_logistic)
        else:
            self.bcalc = Button(axcalc, 'Run Linear Regression (click)', color=dlblue)
            self.bcalc.on_clicked(self.calc_linear)
        self.bthresh = CheckButtons(axthresh, ('Toggle 0.5 threshold (after regression)',))
        self.bthresh.on_clicked(self.thresh)
        self.resize_sq(self.bthresh)

    # @output.capture() # debug
    def add_data(self, event):
        """Mouse-click callback: add a data point where the user clicked.

        Only clicks inside the main axes are handled. The clicked y is snapped
        to 1 if it is above 0.5, otherwise to 0; the point is drawn with the
        matching marker and appended to ``self.x`` / ``self.y``.
        """
        if event.inaxes == self.ax[0]:
            x_coord = event.xdata
            y_coord = event.ydata

            if y_coord > 0.5:
                self.ax[0].scatter(x_coord, 1, marker='x', s=80, c='red')
                self.y = np.append(self.y, 1)
            else:
                self.ax[0].scatter(x_coord, 0, marker='o', s=100, facecolors='none',
                                   edgecolors=dlblue, lw=3)
                self.y = np.append(self.y, 0)
            self.x = np.append(self.x, x_coord)
            self.fig.canvas.draw()

    # @output.capture() # debug
    def calc_linear(self, event):
        """Button callback: fit a linear model by gradient descent, animated.

        gradient_descent is called repeatedly with growing iteration counts,
        each call continuing from the current (w, b). After every call the
        fitted line and its legend are redrawn, with a 0.3 s pause so progress
        is visible. If the threshold overlay is enabled, it is removed during
        training and redrawn for the final model.
        """
        if self.bthresh.get_status()[0]:
            self.remove_thresh()
        for it in [1, 1, 1, 1, 1, 2, 4, 8, 16, 32, 64, 128, 256]:
            # 0.01 is passed positionally after b; presumably the learning
            # rate -- TODO confirm against lab_utils_common.gradient_descent.
            self.w, self.b, _ = gradient_descent(self.x.reshape(-1, 1), self.y.reshape(-1, 1),
                                                 self.w.reshape(-1, 1), self.b, 0.01, it,
                                                 logistic=False, lambda_=0, verbose=False)
            self.aline[0].remove()
            self.alegend.remove()
            y_hat = np.matmul(self.x.reshape(-1, 1), self.w) + self.b
            self.aline = self.ax[0].plot(self.x, y_hat, color=dlblue,
                                         label=f"y = {np.squeeze(self.w):0.2f}x+({self.b:0.2f})")
            self.alegend = self.ax[0].legend(loc='lower right')
            time.sleep(0.3)
            self.fig.canvas.draw()
        if self.bthresh.get_status()[0]:
            self.draw_thresh()
            self.fig.canvas.draw()

    def calc_logistic(self, event):
        """Button callback: fit a logistic model by gradient descent, animated.

        Like calc_linear, but after each gradient_descent call both
        sigmoid(z) (blue) and z = w*x + b (orange) are redrawn over 30 evenly
        spaced points spanning the current x-limits.
        """
        if self.bthresh.get_status()[0]:
            self.remove_thresh()
        for it in [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
            # 0.1 is passed positionally after b; presumably the learning
            # rate -- TODO confirm against lab_utils_common.gradient_descent.
            self.w, self.b, _ = gradient_descent(self.x.reshape(-1, 1), self.y.reshape(-1, 1),
                                                 self.w.reshape(-1, 1), self.b, 0.1, it,
                                                 logistic=True, lambda_=0, verbose=False)
            self.aline[0].remove()
            self.bline[0].remove()
            self.alegend.remove()
            xlim = self.ax[0].get_xlim()
            x_hat = np.linspace(*xlim, 30)
            z_hat = np.matmul(x_hat.reshape(-1, 1), self.w) + self.b  # computed once, used twice
            self.aline = self.ax[0].plot(x_hat, sigmoid(z_hat), color=dlblue,
                                         label="y = sigmoid(z)")
            self.bline = self.ax[0].plot(x_hat, z_hat, color=dlorange, lw=1,
                                         label=f"z = {np.squeeze(self.w):0.2f}x+({self.b:0.2f})")
            self.alegend = self.ax[0].legend(loc='lower right')
            time.sleep(0.3)
            self.fig.canvas.draw()
        if self.bthresh.get_status()[0]:
            self.draw_thresh()
            self.fig.canvas.draw()

    def thresh(self, event):
        """Check-box callback: show the threshold overlay if checked, else hide it."""
        if self.bthresh.get_status()[0]:
            self.draw_thresh()
        else:
            self.remove_thresh()

    def draw_thresh(self):
        """Shade and annotate the two sides of the 0.5 decision threshold.

        The boundary xp5 is the x at which the model output is 0.5. For
        logistic regression this is where z = w*x + b = 0, since
        sigmoid(0) = 0.5. For linear regression it is where w*x + b = 0.5.
        Left of xp5 is shaded blue and annotated "Benign"; right of it is
        shaded dark red and annotated "Malignant". The created artists are
        stored in ``self.tlist`` so remove_thresh can delete them.
        Nothing is drawn when w == 0, because the output is then constant and
        has no 0.5 crossing.
        """
        ws = np.squeeze(self.w)
        if ws == 0:
            return
        xp5 = -self.b / ws if self.logistic else (0.5 - self.b) / ws
        ylim = self.ax[0].get_ylim()
        xlim = self.ax[0].get_xlim()
        a = self.ax[0].fill_between([xlim[0], xp5], [ylim[1], ylim[1]], alpha=0.2, color=dlblue)
        b = self.ax[0].fill_between([xp5, xlim[1]], [ylim[1], ylim[1]], alpha=0.2, color=dldarkred)
        c = self.ax[0].annotate("Malignant", xy=[xp5, 0.5], xycoords='data',
                                xytext=[30, 5], textcoords='offset points')
        d = FancyArrowPatch(
            posA=(xp5, 0.5), posB=(xp5 + 1.5, 0.5), color=dldarkred,
            arrowstyle='simple, head_width=5, head_length=10, tail_width=0.0',
        )
        self.ax[0].add_artist(d)

        e = self.ax[0].annotate("Benign", xy=[xp5, 0.5], xycoords='data',
                                xytext=[-70, 5], textcoords='offset points', ha='left')
        f = FancyArrowPatch(
            posA=(xp5, 0.5), posB=(xp5 - 1.5, 0.5), color=dlblue,
            arrowstyle='simple, head_width=5, head_length=10, tail_width=0.0',
        )
        self.ax[0].add_artist(f)
        self.tlist = [a, b, c, d, e, f]

        self.fig.canvas.draw()

    def remove_thresh(self):
        """Remove the threshold overlay artists created by draw_thresh."""
        for artist in self.tlist:
            artist.remove()
        # Forget the removed artists so a repeated call does not try to remove
        # them again (removing an already-removed artist may raise).
        self.tlist = []
        self.fig.canvas.draw()

    def resize_sq(self, bcid):
        """Enlarge the first check box of ``bcid`` to three times its height.

        This relies on the CheckButtons ``rectangles`` / ``lines`` attributes.
        Matplotlib deprecated them in 3.7 and later removed them. When they
        are absent, the default box size is kept instead of raising
        AttributeError.
        """
        if not (hasattr(bcid, "rectangles") and hasattr(bcid, "lines")):
            return
        box = bcid.rectangles[0]
        box.set_height(3 * box.get_height())

        bbox = box.get_bbox()  # [[xmin, ymin], [xmax, ymax]]
        ymax = bbox.y1
        ymin = bbox.y0

        # Stretch the two diagonal strokes of the check mark to the new height.
        bcid.lines[0][0].set_ydata([ymax, ymin])
        bcid.lines[0][1].set_ydata([ymin, ymax])