Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jxareas
GitHub Repository: jxareas/Machine-Learning-Notebooks
Path: blob/master/1_Supervised_Machine_Learning/Week 3. Classification/plt_one_addpt_onclick.py
2826 views
1
import time
2
import copy
3
from ipywidgets import Output
4
from matplotlib.widgets import Button, CheckButtons
5
from matplotlib.patches import FancyArrowPatch
6
from lab_utils_common import np, plt, dlblue, dlorange, sigmoid, dldarkred, gradient_descent
7
8
# for debug
9
#output = Output() # sends hidden error messages to display when using widgets
10
#display(output)
11
12
class plt_one_addpt_onclick:
    """ class to run one interactive plot

    Interactive figure of 1-D binary data (tumor size vs. label). The user can
    click to add data points, press a button to run gradient descent (logistic
    or linear regression, animated), and tick a check box to overlay the 0.5
    decision threshold.

    Args:
        x: 1-D array of feature values (tumor sizes).
        y: 1-D array of labels; y == 1 is drawn as "malignant", y == 0 as
            "benign".
        w: initial weight(s); deep-copied so the caller's array is untouched.
        b: initial bias.
        logistic: True for logistic regression, False for linear regression.
    """

    def __init__(self, x, y, w, b, logistic=True):
        self.logistic = logistic
        pos = y == 1
        neg = y == 0

        fig, ax = plt.subplots(1, 1, figsize=(8, 4))
        # Hide the canvas toolbar, header and footer.
        fig.canvas.toolbar_visible = False
        fig.canvas.header_visible = False
        fig.canvas.footer_visible = False

        # Leave room under the main axes for the two widgets.
        plt.subplots_adjust(bottom=0.25)
        ax.scatter(x[pos], y[pos], marker='x', s=80, c='red', label="malignant")
        ax.scatter(x[neg], y[neg], marker='o', s=100, label="benign",
                   facecolors='none', edgecolors=dlblue, lw=3)
        ax.set_ylim(-0.05, 1.1)
        xlim = ax.get_xlim()
        # Double the x-range so there is empty space to click new points into.
        ax.set_xlim(xlim[0], xlim[1] * 2)
        ax.set_ylabel('y')
        ax.set_xlabel('Tumor Size')
        self.alegend = ax.legend(loc='lower right')
        if self.logistic:
            ax.set_title("Example of Logistic Regression on Categorical Data")
        else:
            ax.set_title("Example of Linear Regression on Categorical Data")

        ax.text(0.65, 0.8, "[Click to add data points]", size=10, transform=ax.transAxes)

        axcalc = plt.axes([0.1, 0.05, 0.38, 0.075])    # l,b,w,h
        axthresh = plt.axes([0.5, 0.05, 0.38, 0.075])  # l,b,w,h
        self.tlist = []  # artists making up the threshold overlay (empty when hidden)

        self.fig = fig
        self.ax = [ax, axcalc, axthresh]
        self.x = x
        self.y = y
        self.w = copy.deepcopy(w)
        self.b = b
        f_wb = np.matmul(self.x.reshape(-1, 1), self.w) + self.b
        if self.logistic:
            # Blue: sigmoid(z); orange: the linear score z itself.
            self.aline = self.ax[0].plot(self.x, sigmoid(f_wb), color=dlblue)
            self.bline = self.ax[0].plot(self.x, f_wb, color=dlorange, lw=1)
        else:
            # Linear regression predicts f_wb directly (no sigmoid), which is
            # also what calc_linear plots after fitting.
            self.aline = self.ax[0].plot(self.x, f_wb, color=dlblue)

        self.cid = fig.canvas.mpl_connect('button_press_event', self.add_data)
        if self.logistic:
            self.bcalc = Button(axcalc, 'Run Logistic Regression (click)', color=dlblue)
            self.bcalc.on_clicked(self.calc_logistic)
        else:
            self.bcalc = Button(axcalc, 'Run Linear Regression (click)', color=dlblue)
            self.bcalc.on_clicked(self.calc_linear)
        self.bthresh = CheckButtons(axthresh, ('Toggle 0.5 threshold (after regression)',))
        self.bthresh.on_clicked(self.thresh)
        self.resize_sq(self.bthresh)

    # @output.capture() # debug
    def add_data(self, event):
        """Mouse-click handler: add a data point at the clicked x position.

        Clicks outside the main axes are ignored. A click above y = 0.5 adds a
        malignant point (y = 1); anything else adds a benign point (y = 0).
        """
        if event.inaxes == self.ax[0]:
            x_coord = event.xdata
            y_coord = event.ydata

            if y_coord > 0.5:
                self.ax[0].scatter(x_coord, 1, marker='x', s=80, c='red')
                self.y = np.append(self.y, 1)
            else:
                self.ax[0].scatter(x_coord, 0, marker='o', s=100,
                                   facecolors='none', edgecolors=dlblue, lw=3)
                self.y = np.append(self.y, 0)
            self.x = np.append(self.x, x_coord)
        self.fig.canvas.draw()

    # @output.capture() # debug
    def calc_linear(self, event):
        """Button handler: fit linear regression, animating the progress.

        Calls gradient_descent in increasingly long bursts and, after each
        burst, redraws the fitted line with its equation in the legend. If the
        threshold overlay is enabled, it is hidden while fitting and redrawn
        for the final parameters.
        """
        if self.bthresh.get_status()[0]:
            self.remove_thresh()
        # Short bursts first so the early movement is visible, then longer ones.
        for it in [1, 1, 1, 1, 1, 2, 4, 8, 16, 32, 64, 128, 256]:
            self.w, self.b, _ = gradient_descent(self.x.reshape(-1, 1), self.y.reshape(-1, 1),
                                                 self.w.reshape(-1, 1), self.b, 0.01, it,
                                                 logistic=False, lambda_=0, verbose=False)
            self.aline[0].remove()
            self.alegend.remove()
            y_hat = np.matmul(self.x.reshape(-1, 1), self.w) + self.b
            self.aline = self.ax[0].plot(self.x, y_hat, color=dlblue,
                                         label=f"y = {np.squeeze(self.w):0.2f}x+({self.b:0.2f})")
            self.alegend = self.ax[0].legend(loc='lower right')
            time.sleep(0.3)  # pause so each step of the animation is visible
            self.fig.canvas.draw()
        if self.bthresh.get_status()[0]:
            self.draw_thresh()
            self.fig.canvas.draw()

    def calc_logistic(self, event):
        """Button handler: fit logistic regression, animating the progress.

        Calls gradient_descent in increasingly long bursts. After each burst it
        redraws sigmoid(z) (blue) and the linear score z (orange) across the
        current x-limits, with z's equation in the legend. If the threshold
        overlay is enabled, it is hidden while fitting and redrawn at the end.
        """
        if self.bthresh.get_status()[0]:
            self.remove_thresh()
        for it in [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
            self.w, self.b, _ = gradient_descent(self.x.reshape(-1, 1), self.y.reshape(-1, 1),
                                                 self.w.reshape(-1, 1), self.b, 0.1, it,
                                                 logistic=True, lambda_=0, verbose=False)
            self.aline[0].remove()
            self.bline[0].remove()
            self.alegend.remove()
            # Evaluate the model on a grid spanning the visible x-range so the
            # curves extend across the whole plot, not just the data points.
            xlim = self.ax[0].get_xlim()
            x_hat = np.linspace(*xlim, 30)
            f_wb = np.matmul(x_hat.reshape(-1, 1), self.w) + self.b
            y_hat = sigmoid(f_wb)
            self.aline = self.ax[0].plot(x_hat, y_hat, color=dlblue,
                                         label="y = sigmoid(z)")
            self.bline = self.ax[0].plot(x_hat, f_wb, color=dlorange, lw=1,
                                         label=f"z = {np.squeeze(self.w):0.2f}x+({self.b:0.2f})")
            self.alegend = self.ax[0].legend(loc='lower right')
            time.sleep(0.3)  # pause so each step of the animation is visible
            self.fig.canvas.draw()
        if self.bthresh.get_status()[0]:
            self.draw_thresh()
            self.fig.canvas.draw()

    def thresh(self, event):
        """Check-box handler: show or hide the 0.5 threshold overlay."""
        if self.bthresh.get_status()[0]:
            self.draw_thresh()
        else:
            self.remove_thresh()

    def draw_thresh(self):
        """Shade the two sides of the 0.5 decision threshold and label them.

        For logistic regression the threshold is where z = w*x + b = 0 (so
        sigmoid(z) = 0.5). For linear regression it is where w*x + b = 0.5.
        The region left of the threshold is shaded blue and labelled "Benign",
        and the region to the right is shaded dark red and labelled "Malignant".

        Any overlay already present is removed first. If the weight is zero
        (for example, before any regression has run), the prediction is
        constant, there is no threshold crossing, and nothing is drawn.
        """
        self._clear_thresh_artists()
        ws = np.squeeze(self.w)
        if ws == 0:
            # Avoid dividing by zero, which would place the overlay at inf/nan.
            self.fig.canvas.draw()
            return
        xp5 = -self.b / ws if self.logistic else (0.5 - self.b) / ws
        ylim = self.ax[0].get_ylim()
        xlim = self.ax[0].get_xlim()
        # NOTE(review): the Benign-left / Malignant-right layout assumes w > 0;
        # with a negative weight the labels end up on the wrong sides.
        a = self.ax[0].fill_between([xlim[0], xp5], [ylim[1], ylim[1]], alpha=0.2, color=dlblue)
        b = self.ax[0].fill_between([xp5, xlim[1]], [ylim[1], ylim[1]], alpha=0.2, color=dldarkred)
        c = self.ax[0].annotate("Malignant", xy=[xp5, 0.5], xycoords='data',
                                xytext=[30, 5], textcoords='offset points')
        d = FancyArrowPatch(
            posA=(xp5, 0.5), posB=(xp5 + 1.5, 0.5), color=dldarkred,
            arrowstyle='simple, head_width=5, head_length=10, tail_width=0.0',
        )
        self.ax[0].add_artist(d)

        e = self.ax[0].annotate("Benign", xy=[xp5, 0.5], xycoords='data',
                                xytext=[-70, 5], textcoords='offset points', ha='left')
        f = FancyArrowPatch(
            posA=(xp5, 0.5), posB=(xp5 - 1.5, 0.5), color=dlblue,
            arrowstyle='simple, head_width=5, head_length=10, tail_width=0.0',
        )
        self.ax[0].add_artist(f)
        self.tlist = [a, b, c, d, e, f]

        self.fig.canvas.draw()

    def remove_thresh(self):
        """Remove the threshold overlay, if shown, and redraw the canvas."""
        self._clear_thresh_artists()
        self.fig.canvas.draw()

    def _clear_thresh_artists(self):
        """Remove the overlay artists in self.tlist without redrawing.

        The list is emptied afterwards. A repeated call is then a no-op instead
        of calling remove() again on artists that are already gone.
        """
        for artist in self.tlist:
            artist.remove()
        self.tlist = []

    def resize_sq(self, bcid):
        """ resizes the check box

        Makes the first box of the CheckButtons ``bcid`` three times taller
        and stretches the two strokes of its check mark to the new height.

        Uses ``bcid.rectangles`` and ``bcid.lines``. If the widget does not
        expose them, the box is left at its default size.
        NOTE(review): newer matplotlib releases deprecate/remove these
        attributes; confirm against the installed version.
        """
        rectangles = getattr(bcid, "rectangles", None)
        lines = getattr(bcid, "lines", None)
        if not rectangles or not lines:
            return

        box = rectangles[0]
        box.set_height(3 * box.get_height())

        # get_bbox().get_points() is [[xmin, ymin], [xmax, ymax]]
        bbox = box.get_bbox()
        ymax = bbox.y1
        ymin = bbox.y0

        lines[0][0].set_ydata([ymax, ymin])
        lines[0][1].set_ydata([ymin, ymax])
187
188