Path: blob/master/1_Supervised_Machine_Learning/Week 3. Classification/plt_overfit.py
"""1plot_overfit2class and assocaited routines that plot an interactive example of overfitting and its solutions3"""4import math5from ipywidgets import Output6from matplotlib.gridspec import GridSpec7from matplotlib.widgets import Button, CheckButtons8from sklearn.linear_model import LogisticRegression, Ridge9from lab_utils_common import np, plt, dlc, predict_logistic, plot_data, zscore_normalize_features1011def map_one_feature(X1, degree):12"""13Feature mapping function to polynomial features14"""15X1 = np.atleast_1d(X1)16out = []17string = ""18k = 019for i in range(1, degree+1):20out.append((X1**i))21string = string + f"w_{{{k}}}{munge('x_0',i)} + "22k += 123string = string + ' b' #add b to text equation, not to data24return np.stack(out, axis=1), string252627def map_feature(X1, X2, degree):28"""29Feature mapping function to polynomial features30"""31X1 = np.atleast_1d(X1)32X2 = np.atleast_1d(X2)3334out = []35string = ""36k = 037for i in range(1, degree+1):38for j in range(i + 1):39out.append((X1**(i-j) * (X2**j)))40string = string + f"w_{{{k}}}{munge('x_0',i-j)}{munge('x_1',j)} + "41k += 142#print(string + 'b')43return np.stack(out, axis=1), string + ' b'4445def munge(base, exp):46if exp == 0:47return ''48if exp == 1:49return base50return base + f'^{{{exp}}}'5152def plot_decision_boundary(ax, x0r,x1r, predict, w, b, scaler = False, mu=None, sigma=None, degree=None):53"""54Plots a decision boundary55Args:56x0r : (array_like Shape (1,1)) range (min, max) of x057x1r : (array_like Shape (1,1)) range (min, max) of x158predict : function to predict z values59scalar : (boolean) scale data or not60"""6162h = .01 # step size in the mesh63# create a mesh to plot in64xx, yy = np.meshgrid(np.arange(x0r[0], x0r[1], h),65np.arange(x1r[0], x1r[1], h))6667# Plot the decision boundary. For that, we will assign a color to each68# point in the mesh [x_min, m_max]x[y_min, y_max].69points = np.c_[xx.ravel(), yy.ravel()]70Xm,_ = map_feature(points[:, 0], points[:, 1],degree)71if scaler:72Xm = (Xm - mu)/sigma73Z = predict(Xm, w, b)7475# Put the result into a color plot76Z = Z.reshape(xx.shape)77contour = ax.contour(xx, yy, Z, levels = [0.5], colors='g')78return contour7980# use this to test the above routine81def plot_decision_boundary_sklearn(x0r, x1r, predict, degree, scaler = False):82"""83Plots a decision boundary84Args:85x0r : (array_like Shape (1,1)) range (min, max) of x086x1r : (array_like Shape (1,1)) range (min, max) of x187degree: (int) degree of polynomial88predict : function to predict z values89scaler : not sure90"""9192h = .01 # step size in the mesh93# create a mesh to plot in94xx, yy = np.meshgrid(np.arange(x0r[0], x0r[1], h),95np.arange(x1r[0], x1r[1], h))9697# Plot the decision boundary. 
def plot_decision_boundary(ax, x0r, x1r, predict, w, b, scaler=False, mu=None, sigma=None, degree=None):
    """
    Plots a decision boundary
     Args:
      x0r : (array_like Shape (1,1)) range (min, max) of x0
      x1r : (array_like Shape (1,1)) range (min, max) of x1
      predict : function to predict z values
      scaler : (boolean) if True, z-score the mapped features using mu and sigma
    """

    h = .01  # step size in the mesh
    # create a mesh to plot in
    xx, yy = np.meshgrid(np.arange(x0r[0], x0r[1], h),
                         np.arange(x1r[0], x1r[1], h))

    # Plot the decision boundary. For that, we will assign a color to each
    # point in the mesh [x_min, x_max]x[y_min, y_max].
    points = np.c_[xx.ravel(), yy.ravel()]
    Xm, _ = map_feature(points[:, 0], points[:, 1], degree)
    if scaler:
        Xm = (Xm - mu)/sigma
    Z = predict(Xm, w, b)

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    contour = ax.contour(xx, yy, Z, levels=[0.5], colors='g')
    return contour

# use this to test the above routine
def plot_decision_boundary_sklearn(x0r, x1r, predict, degree, scaler=False):
    """
    Plots a decision boundary
     Args:
      x0r : (array_like Shape (1,1)) range (min, max) of x0
      x1r : (array_like Shape (1,1)) range (min, max) of x1
      degree: (int) degree of polynomial
      predict : function to predict z values
      scaler : (object or False) fitted sklearn scaler applied to the mapped features, or False for none
    """

    h = .01  # step size in the mesh
    # create a mesh to plot in
    xx, yy = np.meshgrid(np.arange(x0r[0], x0r[1], h),
                         np.arange(x1r[0], x1r[1], h))

    # Plot the decision boundary. For that, we will assign a color to each
    # point in the mesh [x_min, x_max]x[y_min, y_max].
    points = np.c_[xx.ravel(), yy.ravel()]
    Xm, _ = map_feature(points[:, 0], points[:, 1], degree)  # map_feature returns (features, string)
    if scaler:
        Xm = scaler.transform(Xm)
    Z = predict(Xm)

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.contour(xx, yy, Z, colors='g')
    #plot_data(X_train, y_train)
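# Illustrative test call (hypothetical names): assuming `lr` is a LogisticRegression fit on
# degree-2 mapped features and `scaler` is the StandardScaler fit on those same features,
# something like the following would overlay sklearn's boundary on the current plot:
#
#   plot_decision_boundary_sklearn([-1, 1], [-1, 1], lr.predict, 2, scaler=scaler)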
# for debug, uncomment the #@output.capture() statements below for routines you want to get error output from.
# In the notebook that will call these routines, import `output`:
#     from plt_overfit import overfit_example, output
# then, in a cell, display the captured error messages with:
#     display(output)

output = Output()  # sends hidden error messages to display when using widgets

class button_manager:
    ''' Handles some missing features of matplotlib check buttons
        on init:
            creates button, links to button_click routine,
            calls call_on_click with active index and firsttime=True
        on click:
            maintains single button on state, calls call_on_click
    '''

    @output.capture()  # debug
    def __init__(self, fig, dim, labels, init, call_on_click):
        '''
        dim: (list) [leftbottom_x, bottom_y, width, height]
        labels: (list) for example ['1','2','3','4','5','6']
        init: (list) for example [True, False, False, False, False, False]
        '''
        self.fig = fig
        self.ax = plt.axes(dim)  #lx,by,w,h
        self.init_state = init
        self.call_on_click = call_on_click
        self.button = CheckButtons(self.ax, labels, init)
        self.button.on_clicked(self.button_click)
        self.status = self.button.get_status()
        self.call_on_click(self.status.index(True), firsttime=True)

    @output.capture()  # debug
    def reinit(self):
        self.status = self.init_state
        self.button.set_active(self.status.index(True))  # turn off old, will trigger update and set to status

    @output.capture()  # debug
    def button_click(self, event):
        ''' maintains one-on state. If the currently-on button is clicked, it is handled correctly '''
        #new_status = self.button.get_status()
        #new = [self.status[i] ^ new_status[i] for i in range(len(self.status))]
        #newidx = new.index(True)
        self.button.eventson = False
        self.button.set_active(self.status.index(True))  # turn off old or re-enable if same
        self.button.eventson = True
        self.status = self.button.get_status()
        self.call_on_click(self.status.index(True))
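# Illustrative construction (hypothetical callback name), mirroring how overfit_example
# builds its selectors below; the callback receives the active button's index and must
# accept a `firsttime` keyword:
#
#   def on_degree_change(idx, firsttime=False):
#       print(f"degree set to {idx + 1}")
#
#   degrbut = button_manager(fig, [0.1, 0.02, 0.15, 0.2], ['1', '2', '3', '4', '5', '6'],
#                            [True, False, False, False, False, False], on_degree_change)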
class overfit_example():
    """ plot overfit example """
    # pylint: disable=too-many-instance-attributes
    # pylint: disable=too-many-locals
    # pylint: disable=missing-function-docstring
    # pylint: disable=attribute-defined-outside-init
    def __init__(self, regularize=False):
        self.regularize = regularize
        self.lambda_ = 0
        fig = plt.figure(figsize=(8, 6))
        fig.canvas.toolbar_visible = False
        fig.canvas.header_visible = False
        fig.canvas.footer_visible = False
        fig.set_facecolor('#ffffff')  #white
        gs = GridSpec(5, 3, figure=fig)
        ax0 = fig.add_subplot(gs[0:3, :])
        ax1 = fig.add_subplot(gs[-2, :])
        ax2 = fig.add_subplot(gs[-1, :])
        ax1.set_axis_off()
        ax2.set_axis_off()
        self.ax = [ax0, ax1, ax2]
        self.fig = fig

        self.axfitdata = plt.axes([0.26, 0.124, 0.12, 0.1])  #lx,by,w,h
        self.bfitdata = Button(self.axfitdata, 'fit data', color=dlc['dlblue'])
        self.bfitdata.label.set_fontsize(12)
        self.bfitdata.on_clicked(self.fitdata_clicked)

        #clear data is a future enhancement
        #self.axclrdata = plt.axes([0.26,0.06,0.12,0.05 ]) #lx,by,w,h
        #self.bclrdata = Button(self.axclrdata , 'clear data', color='white')
        #self.bclrdata.label.set_fontsize(12)
        #self.bclrdata.on_clicked(self.clrdata_clicked)

        self.cid = fig.canvas.mpl_connect('button_press_event', self.add_data)

        self.typebut = button_manager(fig, [0.4, 0.07, 0.15, 0.15], ["Regression", "Categorical"],
                                      [False, True], self.toggle_type)

        self.fig.text(0.1, 0.02+0.21, "Degree", fontsize=12)
        self.degrbut = button_manager(fig, [0.1, 0.02, 0.15, 0.2], ['1', '2', '3', '4', '5', '6'],
                                      [True, False, False, False, False, False], self.update_equation)
        if self.regularize:
            self.fig.text(0.6, 0.02+0.21, r"lambda($\lambda$)", fontsize=12)
            self.lambut = button_manager(fig, [0.6, 0.02, 0.15, 0.2], ['0.0', '0.2', '0.4', '0.6', '0.8', '1'],
                                         [True, False, False, False, False, False], self.updt_lambda)

        #self.regbut = button_manager(fig, [0.8, 0.08,0.24,0.15], ["Regularize"],
        #                             [False], self.toggle_reg)
        #self.logistic_data()

    def updt_lambda(self, idx, firsttime=False):
        # pylint: disable=unused-argument
        self.lambda_ = idx * 0.2

    def toggle_type(self, idx, firsttime=False):
        self.logistic = idx == 1
        self.ax[0].clear()
        if self.logistic:
            self.logistic_data()
        else:
            self.linear_data()
        if not firsttime:
            self.degrbut.reinit()

    @output.capture()  # debug
    def logistic_data(self, redraw=False):
        if not redraw:
            m = 50
            n = 2
            np.random.seed(2)
            X_train = 2*(np.random.rand(m, n) - [0.5, 0.5])
            y_train = X_train[:, 1]+0.5 > X_train[:, 0]**2 + 0.5*np.random.rand(m)  #quadratic + random
            y_train = y_train + 0  #convert from boolean to integer
            self.X = X_train
            self.y = y_train
            self.x_ideal = np.sort(X_train[:, 0])
            self.y_ideal = self.x_ideal**2

        self.ax[0].plot(self.x_ideal, self.y_ideal, "--", color="orangered", label="ideal", lw=1)
        plot_data(self.X, self.y, self.ax[0], s=10, loc='lower right')
        self.ax[0].set_title("OverFitting Example: Categorical data set with noise")
        self.ax[0].text(0.5, 0.93, "Click on plot to add data. Hold [Shift] for blue(y=0) data.",
                        fontsize=12, ha='center', transform=self.ax[0].transAxes, color=dlc["dlblue"])
        self.ax[0].set_xlabel(r"$x_0$")
        self.ax[0].set_ylabel(r"$x_1$")

    def linear_data(self, redraw=False):
        if not redraw:
            m = 30
            c = 0
            x_train = np.arange(0, m, 1)
            np.random.seed(1)
            y_ideal = x_train**2 + c
            y_train = y_ideal + 0.7 * y_ideal*(np.random.sample((m,))-0.5)
            self.x_ideal = x_train  #for redraw when new data included in X
            self.X = x_train
            self.y = y_train
            self.y_ideal = y_ideal
        else:
            self.ax[0].set_xlim(self.xlim)
            self.ax[0].set_ylim(self.ylim)

        self.ax[0].scatter(self.X, self.y, label="y")
        self.ax[0].plot(self.x_ideal, self.y_ideal, "--", color="orangered", label="y_ideal", lw=1)
        self.ax[0].set_title("OverFitting Example: Regression Data Set (quadratic with noise)", fontsize=14)
        self.ax[0].set_xlabel("x")
        self.ax[0].set_ylabel("y")
        self.ax0ledgend = self.ax[0].legend(loc='lower right')
        self.ax[0].text(0.5, 0.93, "Click on plot to add data",
                        fontsize=12, ha='center', transform=self.ax[0].transAxes, color=dlc["dlblue"])
        if not redraw:
            self.xlim = self.ax[0].get_xlim()
            self.ylim = self.ax[0].get_ylim()

    @output.capture()  # debug
    def add_data(self, event):
        if self.logistic:
            self.add_data_logistic(event)
        else:
            self.add_data_linear(event)

    @output.capture()  # debug
    def add_data_logistic(self, event):
        if event.inaxes == self.ax[0]:
            x0_coord = event.xdata
            x1_coord = event.ydata

            if event.key is None:  #shift not pressed
                self.ax[0].scatter(x0_coord, x1_coord, marker='x', s=10, c='red', label="y=1")
                self.y = np.append(self.y, 1)
            else:
                self.ax[0].scatter(x0_coord, x1_coord, marker='o', s=10, label="y=0", facecolors='none',
                                   edgecolors=dlc['dlblue'], lw=3)
                self.y = np.append(self.y, 0)
            self.X = np.append(self.X, np.array([[x0_coord, x1_coord]]), axis=0)
            self.fig.canvas.draw()

    def add_data_linear(self, event):
        if event.inaxes == self.ax[0]:
            x_coord = event.xdata
            y_coord = event.ydata

            self.ax[0].scatter(x_coord, y_coord, marker='o', s=10, facecolors='none',
                               edgecolors=dlc['dlblue'], lw=3)
            self.y = np.append(self.y, y_coord)
            self.X = np.append(self.X, x_coord)
            self.fig.canvas.draw()

    #@output.capture()  # debug
    #def clrdata_clicked(self, event):
    #    if self.logistic == True:
    #        self.X = np.
    #    else:
    #        self.linear_regression()

    @output.capture()  # debug
    def fitdata_clicked(self, event):
        if self.logistic:
            self.logistic_regression()
        else:
            self.linear_regression()

    def linear_regression(self):
        self.ax[0].clear()
        self.fig.canvas.draw()

        # create and fit the model using our mapped_X feature set.
        self.X_mapped, _ = map_one_feature(self.X, self.degree)
        self.X_mapped_scaled, self.X_mu, self.X_sigma = zscore_normalize_features(self.X_mapped)

        #linear_model = LinearRegression()
        # note: Ridge's `normalize` argument requires scikit-learn < 1.2 (it was removed in later releases)
        linear_model = Ridge(alpha=self.lambda_, normalize=True, max_iter=10000)
        linear_model.fit(self.X_mapped_scaled, self.y)
        self.w = linear_model.coef_.reshape(-1,)
        self.b = linear_model.intercept_
        x = np.linspace(*self.xlim, 30)  #plot line independent of data which gets disordered
        xm, _ = map_one_feature(x, self.degree)
        xms = (xm - self.X_mu) / self.X_sigma
        y_pred = linear_model.predict(xms)

        #self.fig.canvas.draw()
        self.linear_data(redraw=True)
        self.ax0yfit = self.ax[0].plot(x, y_pred, color="blue", label="y_fit")
        self.ax0ledgend = self.ax[0].legend(loc='lower right')
        self.fig.canvas.draw()
    def logistic_regression(self):
        self.ax[0].clear()
        self.fig.canvas.draw()

        # create and fit the model using our mapped_X feature set.
        self.X_mapped, _ = map_feature(self.X[:, 0], self.X[:, 1], self.degree)
        self.X_mapped_scaled, self.X_mu, self.X_sigma = zscore_normalize_features(self.X_mapped)
        if not self.regularize or self.lambda_ == 0:
            lr = LogisticRegression(penalty='none', max_iter=10000)
        else:
            C = 1/self.lambda_
            lr = LogisticRegression(C=C, max_iter=10000)

        lr.fit(self.X_mapped_scaled, self.y)
        #print(lr.score(self.X_mapped_scaled, self.y))
        self.w = lr.coef_.reshape(-1,)
        self.b = lr.intercept_
        #print(self.w, self.b)
        self.logistic_data(redraw=True)
        self.contour = plot_decision_boundary(self.ax[0], [-1, 1], [-1, 1], predict_logistic, self.w, self.b,
                                              scaler=True, mu=self.X_mu, sigma=self.X_sigma, degree=self.degree)
        self.fig.canvas.draw()

    @output.capture()  # debug
    def update_equation(self, idx, firsttime=False):
        #print(f"Update equation, index = {idx}, firsttime={firsttime}")
        self.degree = idx + 1
        if firsttime:
            self.eqtext = []
        else:
            for artist in self.eqtext:
                #print(artist)
                artist.remove()
            self.eqtext = []
        if self.logistic:
            _, equation = map_feature(self.X[:, 0], self.X[:, 1], self.degree)
            string = 'f_{wb} = sigmoid('
        else:
            _, equation = map_one_feature(self.X, self.degree)
            string = 'f_{wb} = ('
        bz = 10
        seq = equation.split('+')
        blks = math.ceil(len(seq)/bz)
        for i in range(blks):
            if i == 0:
                string = string + '+'.join(seq[bz*i:bz*i+bz])
            else:
                string = '+'.join(seq[bz*i:bz*i+bz])
            string = string + ')' if i == blks-1 else string + '+'
            ei = self.ax[1].text(0.01, (0.75-i*0.25), f"${string}$", fontsize=9,
                                 transform=self.ax[1].transAxes, ma='left', va='top')
            self.eqtext.append(ei)
        self.fig.canvas.draw()
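# ------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the lab code). The widget assumes an
# interactive matplotlib backend in Jupyter (e.g. `%matplotlib widget` via ipympl)
# and `lab_utils_common` on the path. A minimal notebook cell might be:
#
#   %matplotlib widget
#   from plt_overfit import overfit_example, output
#   display(output)                        # show errors captured by @output.capture()
#   ofit = overfit_example(regularize=False)
#
# Click the plot to add points, pick a Degree, then press "fit data" to overlay the
# fitted curve (Regression) or decision boundary (Categorical).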