GitHub Repository: jxareas/Machine-Learning-Notebooks
Path: blob/master/1_Supervised_Machine_Learning/Week 3. Classification/plt_overfit.py
"""
plot_overfit
class and associated routines that plot an interactive example of overfitting and its solutions
"""
import math
from ipywidgets import Output
from matplotlib.gridspec import GridSpec
from matplotlib.widgets import Button, CheckButtons
from sklearn.linear_model import LogisticRegression, Ridge
from lab_utils_common import np, plt, dlc, predict_logistic, plot_data, zscore_normalize_features

def map_one_feature(X1, degree):
    """
    Feature mapping function to polynomial features
    """
    X1 = np.atleast_1d(X1)
    out = []
    string = ""
    k = 0
    for i in range(1, degree+1):
        out.append((X1**i))
        string = string + f"w_{{{k}}}{munge('x_0',i)} + "
        k += 1
    string = string + ' b'  # add b to text equation, not to data
    return np.stack(out, axis=1), string

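# Usage sketch (illustrative only, not called by the labs): map a 1-D feature to
# polynomial terms up to `degree`. The toy array below is an assumed example input.
#   X_demo = np.array([1.0, 2.0, 3.0])
#   X_poly, eqn = map_one_feature(X_demo, degree=3)
#   # X_poly has shape (3, 3) with columns x, x^2, x^3
#   # eqn is a LaTeX-style string like "w_{0}x_0 + w_{1}x_0^{2} + w_{2}x_0^{3} + b"
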
def map_feature(X1, X2, degree):
    """
    Feature mapping function to polynomial features
    """
    X1 = np.atleast_1d(X1)
    X2 = np.atleast_1d(X2)

    out = []
    string = ""
    k = 0
    for i in range(1, degree+1):
        for j in range(i + 1):
            out.append((X1**(i-j) * (X2**j)))
            string = string + f"w_{{{k}}}{munge('x_0',i-j)}{munge('x_1',j)} + "
            k += 1
    #print(string + 'b')
    return np.stack(out, axis=1), string + ' b'

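# Usage sketch (illustrative only): map two features to all polynomial cross terms
# up to `degree`. For degree d the output has d*(d+3)/2 columns; degree=2 gives the
# 5 terms x0, x1, x0^2, x0*x1, x1^2. The inputs below are assumed toy arrays.
#   x0_demo = np.array([1.0, 2.0])
#   x1_demo = np.array([3.0, 4.0])
#   X_poly, eqn = map_feature(x0_demo, x1_demo, degree=2)
#   # X_poly has shape (2, 5)
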
def munge(base, exp):
    if exp == 0:
        return ''
    if exp == 1:
        return base
    return base + f'^{{{exp}}}'

def plot_decision_boundary(ax, x0r, x1r, predict, w, b, scaler=False, mu=None, sigma=None, degree=None):
    """
    Plots a decision boundary
    Args:
      x0r : (array_like Shape (1,1)) range (min, max) of x0
      x1r : (array_like Shape (1,1)) range (min, max) of x1
      predict : function to predict z values
      w, b : model parameters
      scaler : (boolean) z-score the mapped features using mu and sigma
      mu, sigma : mean and standard deviation used for z-score scaling
      degree : (int) degree of the polynomial feature mapping
    """

    h = .01  # step size in the mesh
    # create a mesh to plot in
    xx, yy = np.meshgrid(np.arange(x0r[0], x0r[1], h),
                         np.arange(x1r[0], x1r[1], h))

    # Plot the decision boundary. For that, we will assign a color to each
    # point in the mesh [x_min, x_max]x[y_min, y_max].
    points = np.c_[xx.ravel(), yy.ravel()]
    Xm, _ = map_feature(points[:, 0], points[:, 1], degree)
    if scaler:
        Xm = (Xm - mu)/sigma
    Z = predict(Xm, w, b)

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    contour = ax.contour(xx, yy, Z, levels=[0.5], colors='g')
    return contour

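# Usage sketch (illustrative only): draw the w, b boundary after a fit, the way
# logistic_regression() below does it. `w_fit`, `b_fit`, `mu`, `sigma` are assumed to
# come from a prior fit on map_feature(...) output of the same degree.
#   fig_demo, ax_demo = plt.subplots()
#   plot_decision_boundary(ax_demo, [-1, 1], [-1, 1], predict_logistic, w_fit, b_fit,
#                          scaler=True, mu=mu, sigma=sigma, degree=2)
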
# use this to test the above routine
def plot_decision_boundary_sklearn(x0r, x1r, predict, degree, scaler=False):
    """
    Plots a decision boundary
    Args:
      x0r : (array_like Shape (1,1)) range (min, max) of x0
      x1r : (array_like Shape (1,1)) range (min, max) of x1
      degree : (int) degree of polynomial
      predict : function to predict z values
      scaler : fitted scaler applied to the mapped features with .transform(), or False for no scaling
    """

    h = .01  # step size in the mesh
    # create a mesh to plot in
    xx, yy = np.meshgrid(np.arange(x0r[0], x0r[1], h),
                         np.arange(x1r[0], x1r[1], h))

    # Plot the decision boundary. For that, we will assign a color to each
    # point in the mesh [x_min, x_max]x[y_min, y_max].
    points = np.c_[xx.ravel(), yy.ravel()]
    Xm, _ = map_feature(points[:, 0], points[:, 1], degree)  # map_feature returns (features, equation string)
    if scaler:
        Xm = scaler.transform(Xm)
    Z = predict(Xm)

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.contour(xx, yy, Z, colors='g')
    #plot_data(X_train,y_train)

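# Usage sketch (illustrative only, assumes X_train/y_train already exist and a
# StandardScaler is fitted on the mapped features):
#   from sklearn.preprocessing import StandardScaler
#   Xm_train, _ = map_feature(X_train[:, 0], X_train[:, 1], 2)
#   sc = StandardScaler().fit(Xm_train)
#   lr = LogisticRegression(max_iter=10000).fit(sc.transform(Xm_train), y_train)
#   plot_decision_boundary_sklearn([-1, 1], [-1, 1], lr.predict, 2, scaler=sc)
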
#for debug, uncomment the #@output statements below for routines you want to get error output from
# In the notebook that will call these routines, import `output`:
#     from plt_overfit import overfit_example, output
# then, in a cell where the error messages should appear as output, call:
#     display(output)

output = Output()  # sends hidden error messages to display when using widgets

class button_manager:
    ''' Handles some missing features of matplotlib check buttons
    on init:
        creates button, links to button_click routine,
        calls call_on_click with active index and firsttime=True
    on click:
        maintains single button on state, calls call_on_click
    '''

    @output.capture()  # debug
    def __init__(self, fig, dim, labels, init, call_on_click):
        '''
        dim: (list) [leftbottom_x, bottom_y, width, height]
        labels: (list) for example ['1','2','3','4','5','6']
        init: (list) for example [True, False, False, False, False, False]
        '''
        self.fig = fig
        self.ax = plt.axes(dim)  #lx,by,w,h
        self.init_state = init
        self.call_on_click = call_on_click
        self.button = CheckButtons(self.ax, labels, init)
        self.button.on_clicked(self.button_click)
        self.status = self.button.get_status()
        self.call_on_click(self.status.index(True), firsttime=True)

    @output.capture()  # debug
    def reinit(self):
        self.status = self.init_state
        self.button.set_active(self.status.index(True))  #turn off old, will trigger update and set to status

    @output.capture()  # debug
    def button_click(self, event):
        ''' maintains one-on state. If on-button is clicked, will process correctly '''
        #new_status = self.button.get_status()
        #new = [self.status[i] ^ new_status[i] for i in range(len(self.status))]
        #newidx = new.index(True)
        self.button.eventson = False
        self.button.set_active(self.status.index(True))  #turn off old or reenable if same
        self.button.eventson = True
        self.status = self.button.get_status()
        self.call_on_click(self.status.index(True))

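# Usage sketch (illustrative only): a radio-button-style group built from CheckButtons.
# The callback name and axes position below are assumptions, not part of the labs.
#   def on_pick(idx, firsttime=False):
#       print(f"selected option {idx}")
#   fig_demo = plt.figure()
#   picker = button_manager(fig_demo, [0.1, 0.1, 0.2, 0.2], ['a', 'b', 'c'],
#                           [True, False, False], on_pick)
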
class overfit_example():
    """ plot overfit example """
    # pylint: disable=too-many-instance-attributes
    # pylint: disable=too-many-locals
    # pylint: disable=missing-function-docstring
    # pylint: disable=attribute-defined-outside-init
    def __init__(self, regularize=False):
        self.regularize = regularize
        self.lambda_ = 0
        fig = plt.figure(figsize=(8, 6))
        fig.canvas.toolbar_visible = False
        fig.canvas.header_visible = False
        fig.canvas.footer_visible = False
        fig.set_facecolor('#ffffff')  #white
        gs = GridSpec(5, 3, figure=fig)
        ax0 = fig.add_subplot(gs[0:3, :])
        ax1 = fig.add_subplot(gs[-2, :])
        ax2 = fig.add_subplot(gs[-1, :])
        ax1.set_axis_off()
        ax2.set_axis_off()
        self.ax = [ax0, ax1, ax2]
        self.fig = fig

        self.axfitdata = plt.axes([0.26, 0.124, 0.12, 0.1])  #lx,by,w,h
        self.bfitdata = Button(self.axfitdata, 'fit data', color=dlc['dlblue'])
        self.bfitdata.label.set_fontsize(12)
        self.bfitdata.on_clicked(self.fitdata_clicked)

        #clear data is a future enhancement
        #self.axclrdata = plt.axes([0.26,0.06,0.12,0.05 ]) #lx,by,w,h
        #self.bclrdata = Button(self.axclrdata , 'clear data', color='white')
        #self.bclrdata.label.set_fontsize(12)
        #self.bclrdata.on_clicked(self.clrdata_clicked)

        self.cid = fig.canvas.mpl_connect('button_press_event', self.add_data)

        self.typebut = button_manager(fig, [0.4, 0.07, 0.15, 0.15], ["Regression", "Categorical"],
                                      [False, True], self.toggle_type)

        self.fig.text(0.1, 0.02+0.21, "Degree", fontsize=12)
        self.degrbut = button_manager(fig, [0.1, 0.02, 0.15, 0.2], ['1', '2', '3', '4', '5', '6'],
                                      [True, False, False, False, False, False], self.update_equation)
        if self.regularize:
            self.fig.text(0.6, 0.02+0.21, r"lambda($\lambda$)", fontsize=12)
            self.lambut = button_manager(fig, [0.6, 0.02, 0.15, 0.2], ['0.0', '0.2', '0.4', '0.6', '0.8', '1'],
                                         [True, False, False, False, False, False], self.updt_lambda)

        #self.regbut = button_manager(fig, [0.8, 0.08,0.24,0.15], ["Regularize"],
        #                             [False], self.toggle_reg)
        #self.logistic_data()

    def updt_lambda(self, idx, firsttime=False):
        # pylint: disable=unused-argument
        self.lambda_ = idx * 0.2

    def toggle_type(self, idx, firsttime=False):
        self.logistic = idx == 1
        self.ax[0].clear()
        if self.logistic:
            self.logistic_data()
        else:
            self.linear_data()
        if not firsttime:
            self.degrbut.reinit()

    @output.capture()  # debug
    def logistic_data(self, redraw=False):
        if not redraw:
            m = 50
            n = 2
            np.random.seed(2)
            X_train = 2*(np.random.rand(m, n) - [0.5, 0.5])
            y_train = X_train[:, 1]+0.5 > X_train[:, 0]**2 + 0.5*np.random.rand(m)  #quadratic + random
            y_train = y_train + 0  #convert from boolean to integer
            self.X = X_train
            self.y = y_train
            self.x_ideal = np.sort(X_train[:, 0])
            self.y_ideal = self.x_ideal**2

        self.ax[0].plot(self.x_ideal, self.y_ideal, "--", color="orangered", label="ideal", lw=1)
        plot_data(self.X, self.y, self.ax[0], s=10, loc='lower right')
        self.ax[0].set_title("OverFitting Example: Categorical data set with noise")
        self.ax[0].text(0.5, 0.93, "Click on plot to add data. Hold [Shift] for blue(y=0) data.",
                        fontsize=12, ha='center', transform=self.ax[0].transAxes, color=dlc["dlblue"])
        self.ax[0].set_xlabel(r"$x_0$")
        self.ax[0].set_ylabel(r"$x_1$")

    def linear_data(self, redraw=False):
        if not redraw:
            m = 30
            c = 0
            x_train = np.arange(0, m, 1)
            np.random.seed(1)
            y_ideal = x_train**2 + c
            y_train = y_ideal + 0.7 * y_ideal*(np.random.sample((m,)) - 0.5)
            self.x_ideal = x_train  #for redraw when new data included in X
            self.X = x_train
            self.y = y_train
            self.y_ideal = y_ideal
        else:
            self.ax[0].set_xlim(self.xlim)
            self.ax[0].set_ylim(self.ylim)

        self.ax[0].scatter(self.X, self.y, label="y")
        self.ax[0].plot(self.x_ideal, self.y_ideal, "--", color="orangered", label="y_ideal", lw=1)
        self.ax[0].set_title("OverFitting Example: Regression Data Set (quadratic with noise)", fontsize=14)
        self.ax[0].set_xlabel("x")
        self.ax[0].set_ylabel("y")
        self.ax0ledgend = self.ax[0].legend(loc='lower right')
        self.ax[0].text(0.5, 0.93, "Click on plot to add data",
                        fontsize=12, ha='center', transform=self.ax[0].transAxes, color=dlc["dlblue"])
        if not redraw:
            self.xlim = self.ax[0].get_xlim()
            self.ylim = self.ax[0].get_ylim()

    @output.capture()  # debug
    def add_data(self, event):
        if self.logistic:
            self.add_data_logistic(event)
        else:
            self.add_data_linear(event)

    @output.capture()  # debug
    def add_data_logistic(self, event):
        if event.inaxes == self.ax[0]:
            x0_coord = event.xdata
            x1_coord = event.ydata

            if event.key is None:  #shift not pressed
                self.ax[0].scatter(x0_coord, x1_coord, marker='x', s=10, c='red', label="y=1")
                self.y = np.append(self.y, 1)
            else:
                self.ax[0].scatter(x0_coord, x1_coord, marker='o', s=10, label="y=0", facecolors='none',
                                   edgecolors=dlc['dlblue'], lw=3)
                self.y = np.append(self.y, 0)
            self.X = np.append(self.X, np.array([[x0_coord, x1_coord]]), axis=0)
            self.fig.canvas.draw()

    def add_data_linear(self, event):
        if event.inaxes == self.ax[0]:
            x_coord = event.xdata
            y_coord = event.ydata

            self.ax[0].scatter(x_coord, y_coord, marker='o', s=10, facecolors='none',
                               edgecolors=dlc['dlblue'], lw=3)
            self.y = np.append(self.y, y_coord)
            self.X = np.append(self.X, x_coord)
            self.fig.canvas.draw()

    #@output.capture() # debug
    #def clrdata_clicked(self,event):
    #    if self.logistic == True:
    #        self.X = np.
    #    else:
    #        self.linear_regression()

    @output.capture()  # debug
    def fitdata_clicked(self, event):
        if self.logistic:
            self.logistic_regression()
        else:
            self.linear_regression()

    def linear_regression(self):
        self.ax[0].clear()
        self.fig.canvas.draw()

        # create and fit the model using our mapped_X feature set.
        self.X_mapped, _ = map_one_feature(self.X, self.degree)
        self.X_mapped_scaled, self.X_mu, self.X_sigma = zscore_normalize_features(self.X_mapped)

        #linear_model = LinearRegression()
        # features are already z-score normalized above, so no extra normalization is applied here
        linear_model = Ridge(alpha=self.lambda_, max_iter=10000)
        linear_model.fit(self.X_mapped_scaled, self.y)
        self.w = linear_model.coef_.reshape(-1,)
        self.b = linear_model.intercept_
        x = np.linspace(*self.xlim, 30)  #plot line independent of data which gets disordered
        xm, _ = map_one_feature(x, self.degree)
        xms = (xm - self.X_mu) / self.X_sigma
        y_pred = linear_model.predict(xms)

        #self.fig.canvas.draw()
        self.linear_data(redraw=True)
        self.ax0yfit = self.ax[0].plot(x, y_pred, color="blue", label="y_fit")
        self.ax0ledgend = self.ax[0].legend(loc='lower right')
        self.fig.canvas.draw()

    def logistic_regression(self):
        self.ax[0].clear()
        self.fig.canvas.draw()

        # create and fit the model using our mapped_X feature set.
        self.X_mapped, _ = map_feature(self.X[:, 0], self.X[:, 1], self.degree)
        self.X_mapped_scaled, self.X_mu, self.X_sigma = zscore_normalize_features(self.X_mapped)
        if not self.regularize or self.lambda_ == 0:
            lr = LogisticRegression(penalty=None, max_iter=10000)  #unregularized fit
        else:
            C = 1/self.lambda_  #sklearn's C is the inverse of the regularization strength lambda
            lr = LogisticRegression(C=C, max_iter=10000)

        lr.fit(self.X_mapped_scaled, self.y)
        #print(lr.score(self.X_mapped_scaled, self.y))
        self.w = lr.coef_.reshape(-1,)
        self.b = lr.intercept_
        #print(self.w, self.b)
        self.logistic_data(redraw=True)
        self.contour = plot_decision_boundary(self.ax[0], [-1, 1], [-1, 1], predict_logistic, self.w, self.b,
                                              scaler=True, mu=self.X_mu, sigma=self.X_sigma, degree=self.degree)
        self.fig.canvas.draw()

    @output.capture()  # debug
    def update_equation(self, idx, firsttime=False):
        #print(f"Update equation, index = {idx}, firsttime={firsttime}")
        self.degree = idx + 1
        if firsttime:
            self.eqtext = []
        else:
            for artist in self.eqtext:
                #print(artist)
                artist.remove()
            self.eqtext = []
        if self.logistic:
            _, equation = map_feature(self.X[:, 0], self.X[:, 1], self.degree)
            string = 'f_{wb} = sigmoid('
        else:
            _, equation = map_one_feature(self.X, self.degree)
            string = 'f_{wb} = ('
        bz = 10
        seq = equation.split('+')
        blks = math.ceil(len(seq)/bz)
        for i in range(blks):
            if i == 0:
                string = string + '+'.join(seq[bz*i:bz*i+bz])
            else:
                string = '+'.join(seq[bz*i:bz*i+bz])
            string = string + ')' if i == blks-1 else string + '+'
            ei = self.ax[1].text(0.01, (0.75-i*0.25), f"${string}$", fontsize=9,
                                 transform=self.ax[1].transAxes, ma='left', va='top')
            self.eqtext.append(ei)
        self.fig.canvas.draw()
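
# Notebook usage sketch (illustrative only, assumes an interactive ipympl backend
# such as `%matplotlib widget` and that this module is importable from the notebook):
#   from plt_overfit import overfit_example, output
#   display(output)                           # optional: surface errors from widget callbacks
#   ofit = overfit_example(regularize=True)   # interactive figure; click to add data, press 'fit data'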