Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download

📚 The CoCalc Library - books, templates and other resources

132928 views
License: OTHER
"""
==========
Libsvm GUI
==========

A simple graphical frontend for Libsvm mainly intended for didactic
purposes. You can create data points by point and click and visualize
the decision region induced by different kernels and parameter settings.

To create positive examples click the left mouse button; to create
negative examples click the right button (or hold Shift while clicking).

If all examples are from the same class, it uses a one-class SVM.

"""
from __future__ import division, print_function

print(__doc__)

# Author: Peter Prettenhofer <[email protected]>
#
# License: BSD 3 clause
import matplotlib
25
matplotlib.use('TkAgg')
26
27
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
28
from matplotlib.backends.backend_tkagg import NavigationToolbar2TkAgg
29
from matplotlib.figure import Figure
30
from matplotlib.contour import ContourSet
31
32
import Tkinter as Tk
33
import sys
34
import numpy as np
35
36
from sklearn import svm
37
from sklearn.datasets import dump_svmlight_file
38
from sklearn.externals.six.moves import xrange
39
40
# Data-space bounds of the plotting area. The view fixes its axis limits
# to this box and the decision surface is evaluated on a grid covering it.
x_min, x_max = -50, 50
y_min, y_max = -50, 50
class Model(object):
    """The Model which holds the data.

    It implements the observable in the observer pattern and notifies the
    registered observers on change events.

    Attributes
    ----------
    observers : list
        Objects with an ``update(event, model)`` method.
    surface : tuple or None
        ``(X1, X2, Z)`` grid of the last decision surface, None before the
        first fit.
    data : list of tuple
        One ``(x, y, label)`` tuple per example.
    clf : estimator or None
        The last fitted SVM (set by the controller), None before any fit.
    surface_type : int
        0 to draw hyperplanes, 1 to draw a filled surface.
    """

    def __init__(self):
        self.observers = []
        self.surface = None
        self.data = []
        # The fitted estimator is read and written as ``clf`` everywhere
        # else; it was previously only initialised under the misspelled
        # name ``cls``, which is kept for backward compatibility.
        self.clf = None
        self.cls = None
        self.surface_type = 0

    def changed(self, event):
        """Notify the observers of ``event``. """
        for observer in self.observers:
            observer.update(event, self)

    def add_observer(self, observer):
        """Register an observer. """
        self.observers.append(observer)

    def set_surface(self, surface):
        """Store the ``(X1, X2, Z)`` decision surface. """
        self.surface = surface

    def dump_svmlight_file(self, file):
        """Write the examples to ``file`` in svmlight/libsvm format.

        Parameters
        ----------
        file
            Destination handed to sklearn's ``dump_svmlight_file`` (``main``
            passes a path).

        Raises
        ------
        ValueError
            If there are no examples to write.
        """
        if not self.data:
            # Slicing the empty array below would fail with an opaque
            # IndexError; report the actual problem instead.
            raise ValueError("No examples to dump.")
        data = np.array(self.data)
        X = data[:, 0:2]
        y = data[:, 2]
        dump_svmlight_file(X, y, file)
class Controller(object):
    """Turns user actions into model updates and (re)fits the SVM.

    The Tk variables ``kernel`` and ``surface_type`` are created here; the
    entry variables ``complexity``, ``gamma``, ``degree`` and ``coef0`` are
    attached afterwards by ``ControllBar``.
    """

    # Value of the kernel radio buttons -> sklearn kernel name.
    _KERNELS = {0: "linear", 1: "rbf", 2: "poly"}

    def __init__(self, model):
        self.model = model
        self.kernel = Tk.IntVar()
        self.surface_type = Tk.IntVar()
        # Whether or not a model has been fitted
        self.fitted = False

    def fit(self):
        """Fit an SVM on the current examples and publish its surface.

        Uses ``svm.OneClassSVM`` when every example has the same label and
        ``svm.SVC`` otherwise, stores the estimator and its decision surface
        on the model, and notifies observers with ``"surface"``. When there
        are no examples, or a parameter entry cannot be parsed, a message is
        printed and nothing is fitted.
        """
        print("fit the model")
        if not self.model.data:
            # np.array([]) is 1-D, so the column slicing below would fail.
            print("No examples to fit: click on the canvas to add some.")
            return
        train = np.array(self.model.data)
        X = train[:, 0:2]
        y = train[:, 2]

        try:
            C = float(self.complexity.get())
            gamma = float(self.gamma.get())
            coef0 = float(self.coef0.get())
            degree = int(self.degree.get())
        except ValueError as err:
            # Free-text Tk entries: report bad input instead of raising out
            # of a GUI callback.
            print("Invalid SVM parameter:", err)
            return

        kernel = self._KERNELS[self.kernel.get()]
        if len(np.unique(y)) == 1:
            # SVC needs two classes; with a single class fall back to a
            # one-class SVM (C is not passed to it).
            clf = svm.OneClassSVM(kernel=kernel, gamma=gamma, coef0=coef0,
                                  degree=degree)
            clf.fit(X)
        else:
            clf = svm.SVC(kernel=kernel, C=C, gamma=gamma, coef0=coef0,
                          degree=degree)
            clf.fit(X, y)
        if hasattr(clf, 'score'):
            print("Accuracy:", clf.score(X, y) * 100)
        X1, X2, Z = self.decision_surface(clf)
        self.model.clf = clf
        self.model.set_surface((X1, X2, Z))
        self.model.surface_type = self.surface_type.get()
        self.fitted = True
        self.model.changed("surface")

    def decision_surface(self, cls):
        """Evaluate ``cls.decision_function`` on a unit-spaced grid.

        The grid spans ``[x_min, x_max] x [y_min, y_max]``; returns the
        meshgrid arrays ``X1``, ``X2`` and the values ``Z`` reshaped to
        ``X1.shape``.
        """
        delta = 1
        x = np.arange(x_min, x_max + delta, delta)
        y = np.arange(y_min, y_max + delta, delta)
        X1, X2 = np.meshgrid(x, y)
        Z = cls.decision_function(np.c_[X1.ravel(), X2.ravel()])
        Z = Z.reshape(X1.shape)
        return X1, X2, Z

    def clear_data(self):
        """Drop all examples, mark the model unfitted, notify ``"clear"``."""
        self.model.data = []
        self.fitted = False
        self.model.changed("clear")

    def add_example(self, x, y, label):
        """Append an ``(x, y, label)`` example and refit if already fitted."""
        self.model.data.append((x, y, label))
        self.model.changed("example_added")

        # update decision surface if already fitted.
        self.refit()

    def refit(self):
        """Refit the model if already fitted. """
        if self.fitted:
            self.fit()
class View(object):
    """Matplotlib/Tk view of the model.

    Draws the examples, the kernel legend, the support vectors and the
    decision surface, and forwards mouse clicks to the controller. It is
    registered as an observer of the model and redraws in ``update``.
    """

    def __init__(self, root, controller):
        f = Figure()
        ax = f.add_subplot(111)
        self.f = f
        self.ax = ax
        self.controller = controller
        # Artists of the current decision surface (contour sets and the
        # support-vector scatter); removed before a new surface is drawn.
        self.contours = []
        self.c_labels = None
        self.shift_down = False
        self._reset_axes()

        canvas = FigureCanvasTkAgg(f, master=root)
        # ``show()`` was deprecated in matplotlib 2.2 and removed in 3.0;
        # ``draw()`` renders the figure on old and new versions alike.
        canvas.draw()
        # get_tk_widget() already is the Tk canvas, so one pack suffices.
        canvas.get_tk_widget().pack(side=Tk.TOP, fill=Tk.BOTH, expand=1)
        canvas.mpl_connect('key_press_event', self.onkeypress)
        canvas.mpl_connect('key_release_event', self.onkeyrelease)
        canvas.mpl_connect('button_press_event', self.onclick)
        toolbar = NavigationToolbar2TkAgg(canvas, root)
        toolbar.update()
        self.controllbar = ControllBar(root, controller)
        self.canvas = canvas

    def _reset_axes(self):
        """Hide the ticks, fix the data limits and draw the kernel legend."""
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        self.ax.set_xlim((x_min, x_max))
        self.ax.set_ylim((y_min, y_max))
        self.plot_kernels()

    def plot_kernels(self):
        """Write the three kernel formulas below the plotting area."""
        # Raw strings: the LaTeX backslashes are not Python escapes.
        self.ax.text(-50, -60, r"Linear: $u^T v$")
        self.ax.text(-20, -60, r"RBF: $\exp (-\gamma \| u-v \|^2)$")
        self.ax.text(10, -60, r"Poly: $(\gamma \, u^T v + r)^d$")

    def onkeypress(self, event):
        """Remember that Shift is held (Shift + click adds negatives)."""
        if event.key == "shift":
            self.shift_down = True

    def onkeyrelease(self, event):
        """Forget the Shift state once the key is released."""
        if event.key == "shift":
            self.shift_down = False

    def onclick(self, event):
        """Add an example at the clicked data coordinates.

        Shift + click or a right click (button 3) adds a negative example
        (label -1); a plain left click (button 1) adds a positive one.
        """
        # xdata/ydata are None when the click is outside the axes. Compare
        # against None explicitly so clicks on x == 0 or y == 0 still count.
        if event.xdata is None or event.ydata is None:
            return
        if self.shift_down or event.button == 3:
            self.controller.add_example(event.xdata, event.ydata, -1)
        elif event.button == 1:
            self.controller.add_example(event.xdata, event.ydata, 1)

    def update_example(self, model, idx):
        """Plot example ``idx``: white marker for label 1, black otherwise."""
        x, y, label = model.data[idx]
        # Any label other than 1 is drawn as negative instead of leaving
        # the colour unbound.
        color = 'w' if label == 1 else 'k'
        self.ax.plot([x], [y], color + "o", scalex=False, scaley=False)

    def update(self, event, model):
        """Observer callback: redraw according to the model ``event``.

        Handles ``"examples_loaded"``, ``"example_added"``, ``"clear"`` and
        ``"surface"``; the canvas is redrawn after every event.
        """
        if event == "examples_loaded":
            for idx in range(len(model.data)):
                self.update_example(model, idx)
        elif event == "example_added":
            self.update_example(model, -1)
        elif event == "clear":
            self.ax.clear()
            self.contours = []
            self.c_labels = None
            # ax.clear() also resets the axis limits, so re-apply the fixed
            # limits (not just the ticks) together with the kernel legend.
            self._reset_axes()
        elif event == "surface":
            self.remove_surface()
            self.plot_support_vectors(model.clf.support_vectors_)
            self.plot_decision_surface(model.surface, model.surface_type)

        self.canvas.draw()

    def remove_surface(self):
        """Remove the old decision surface and support-vector markers."""
        for artist in self.contours:
            if (isinstance(artist, ContourSet)
                    and not isinstance(artist, matplotlib.artist.Artist)):
                # Before matplotlib 3.8 a ContourSet was not itself an
                # Artist; its line/patch collections must be removed.
                for lineset in artist.collections:
                    lineset.remove()
            else:
                # Scatter collections, and ContourSets on matplotlib >= 3.8
                # (now Collections), remove themselves.
                artist.remove()
        self.contours = []

    def plot_support_vectors(self, support_vectors):
        """Plot the support vectors by placing circles over the
        corresponding data points and adds the circle collection
        to the contours list."""
        cs = self.ax.scatter(support_vectors[:, 0], support_vectors[:, 1],
                             s=80, edgecolors="k", facecolors="none")
        self.contours.append(cs)

    def plot_decision_surface(self, surface, type):
        """Draw the ``(X1, X2, Z)`` surface.

        ``type`` 0 draws the decision boundary (solid) and the +/-1 margins
        (dashed); ``type`` 1 draws filled contours plus the zero level.

        Raises
        ------
        ValueError
            For any other ``type``.
        """
        X1, X2, Z = surface
        if type == 0:
            levels = [-1.0, 0.0, 1.0]
            linestyles = ['dashed', 'solid', 'dashed']
            colors = 'k'
            self.contours.append(self.ax.contour(X1, X2, Z, levels,
                                                 colors=colors,
                                                 linestyles=linestyles))
        elif type == 1:
            self.contours.append(self.ax.contourf(X1, X2, Z, 10,
                                                  cmap=matplotlib.cm.bone,
                                                  origin='lower', alpha=0.85))
            self.contours.append(self.ax.contour(X1, X2, Z, [0.0], colors='k',
                                                 linestyles=['solid']))
        else:
            raise ValueError("surface type unknown")
class ControllBar(object):
    """Row of Tk widgets below the plot.

    It holds the kernel selector, the parameter entries, the surface-style
    selector and the Fit / Clear buttons. As a side effect it attaches the
    ``complexity``, ``gamma``, ``degree`` and ``coef0`` string variables to
    ``controller``.
    """

    def __init__(self, root, controller):
        bar = Tk.Frame(root)

        # Kernel selector: every choice triggers a refit of a fitted model.
        kernels = Tk.Frame(bar)
        for caption, kernel_id in (("Linear", 0), ("RBF", 1), ("Poly", 2)):
            radio = Tk.Radiobutton(kernels, text=caption,
                                   variable=controller.kernel,
                                   value=kernel_id, command=controller.refit)
            radio.pack(anchor=Tk.W)
        kernels.pack(side=Tk.LEFT)

        # One labelled entry row per hyper-parameter; the backing StringVar
        # is stored on the controller under ``attr`` for Controller.fit.
        params = Tk.Frame(bar)
        entry_specs = (("complexity", "C:", "1.0"),
                       ("gamma", "gamma:", "0.01"),
                       ("degree", "degree:", "3"),
                       ("coef0", "coef0:", "0"))
        for attr, caption, default in entry_specs:
            var = Tk.StringVar()
            setattr(controller, attr, var)
            var.set(default)
            row = Tk.Frame(params)
            Tk.Label(row, text=caption, anchor="e", width=7).pack(side=Tk.LEFT)
            Tk.Entry(row, width=6, textvariable=var).pack(side=Tk.LEFT)
            row.pack()
        params.pack(side=Tk.LEFT)

        # Surface style: hyperplanes (0) or filled surface (1).
        styles = Tk.Frame(bar)
        for caption, style_id in (("Hyperplanes", 0), ("Surface", 1)):
            radio = Tk.Radiobutton(styles, text=caption,
                                   variable=controller.surface_type,
                                   value=style_id, command=controller.refit)
            radio.pack(anchor=Tk.W)
        styles.pack(side=Tk.LEFT)

        fit_button = Tk.Button(bar, text="Fit", width=5,
                               command=controller.fit)
        fit_button.pack()
        bar.pack(side=Tk.LEFT)
        clear_button = Tk.Button(bar, text="Clear", width=5,
                                 command=controller.clear_data)
        clear_button.pack(side=Tk.LEFT)
def get_parser():
    """Return the ``optparse`` parser for this script's command line.

    The single option is ``--output PATH``: where ``main`` writes the
    clicked examples in svmlight format after the window is closed.
    """
    from optparse import OptionParser

    parser = OptionParser()
    parser.add_option("--output", dest="output", action="store", type="str",
                      help="Path where to dump data.")
    return parser
def main(argv):
    """Run the GUI, then optionally dump the examples that were created.

    Parameters
    ----------
    argv : list of str
        Full argument vector (as ``sys.argv``); ``argv[0]`` is skipped.
    """
    op = get_parser()
    opts, _ = op.parse_args(argv[1:])
    root = Tk.Tk()
    model = Model()
    controller = Controller(model)
    root.wm_title("Scikit-learn Libsvm GUI")
    view = View(root, controller)
    model.add_observer(view)
    Tk.mainloop()

    if opts.output:
        if model.data:
            model.dump_svmlight_file(opts.output)
        else:
            # Model.dump_svmlight_file cannot write an empty data set;
            # skip the file instead of crashing after the window closes.
            print("No examples were created; not writing %s." % opts.output)
if __name__ == '__main__':
    # Script entry point: hand over the complete argv (main skips argv[0]).
    cli_args = sys.argv
    main(cli_args)