📚 The CoCalc Library - books, templates and other resources
License: OTHER
"""1==========2Libsvm GUI3==========45A simple graphical frontend for Libsvm mainly intended for didactic6purposes. You can create data points by point and click and visualize7the decision region induced by different kernels and parameter settings.89To create positive examples click the left mouse button; to create10negative examples click the right button.1112If all examples are from the same class, it uses a one-class SVM.1314"""15from __future__ import division, print_function1617print(__doc__)1819# Author: Peter Prettenhoer <[email protected]>20#21# License: BSD 3 clause2223import matplotlib24matplotlib.use('TkAgg')2526from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg27from matplotlib.backends.backend_tkagg import NavigationToolbar2TkAgg28from matplotlib.figure import Figure29from matplotlib.contour import ContourSet3031import Tkinter as Tk32import sys33import numpy as np3435from sklearn import svm36from sklearn.datasets import dump_svmlight_file37from sklearn.externals.six.moves import xrange3839y_min, y_max = -50, 5040x_min, x_max = -50, 50414243class Model(object):44"""The Model which hold the data. It implements the45observable in the observer pattern and notifies the46registered observers on change event.47"""4849def __init__(self):50self.observers = []51self.surface = None52self.data = []53self.cls = None54self.surface_type = 05556def changed(self, event):57"""Notify the observers. """58for observer in self.observers:59observer.update(event, self)6061def add_observer(self, observer):62"""Register an observer. """63self.observers.append(observer)6465def set_surface(self, surface):66self.surface = surface6768def dump_svmlight_file(self, file):69data = np.array(self.data)70X = data[:, 0:2]71y = data[:, 2]72dump_svmlight_file(X, y, file)737475class Controller(object):76def __init__(self, model):77self.model = model78self.kernel = Tk.IntVar()79self.surface_type = Tk.IntVar()80# Whether or not a model has been fitted81self.fitted = False8283def fit(self):84print("fit the model")85train = np.array(self.model.data)86X = train[:, 0:2]87y = train[:, 2]8889C = float(self.complexity.get())90gamma = float(self.gamma.get())91coef0 = float(self.coef0.get())92degree = int(self.degree.get())93kernel_map = {0: "linear", 1: "rbf", 2: "poly"}94if len(np.unique(y)) == 1:95clf = svm.OneClassSVM(kernel=kernel_map[self.kernel.get()],96gamma=gamma, coef0=coef0, degree=degree)97clf.fit(X)98else:99clf = svm.SVC(kernel=kernel_map[self.kernel.get()], C=C,100gamma=gamma, coef0=coef0, degree=degree)101clf.fit(X, y)102if hasattr(clf, 'score'):103print("Accuracy:", clf.score(X, y) * 100)104X1, X2, Z = self.decision_surface(clf)105self.model.clf = clf106self.model.set_surface((X1, X2, Z))107self.model.surface_type = self.surface_type.get()108self.fitted = True109self.model.changed("surface")110111def decision_surface(self, cls):112delta = 1113x = np.arange(x_min, x_max + delta, delta)114y = np.arange(y_min, y_max + delta, delta)115X1, X2 = np.meshgrid(x, y)116Z = cls.decision_function(np.c_[X1.ravel(), X2.ravel()])117Z = Z.reshape(X1.shape)118return X1, X2, Z119120def clear_data(self):121self.model.data = []122self.fitted = False123self.model.changed("clear")124125def add_example(self, x, y, label):126self.model.data.append((x, y, label))127self.model.changed("example_added")128129# update decision surface if already fitted.130self.refit()131132def refit(self):133"""Refit the model if already fitted. """134if self.fitted:135self.fit()136137138class View(object):139"""Test docstring. """140def __init__(self, root, controller):141f = Figure()142ax = f.add_subplot(111)143ax.set_xticks([])144ax.set_yticks([])145ax.set_xlim((x_min, x_max))146ax.set_ylim((y_min, y_max))147canvas = FigureCanvasTkAgg(f, master=root)148canvas.show()149canvas.get_tk_widget().pack(side=Tk.TOP, fill=Tk.BOTH, expand=1)150canvas._tkcanvas.pack(side=Tk.TOP, fill=Tk.BOTH, expand=1)151canvas.mpl_connect('key_press_event', self.onkeypress)152canvas.mpl_connect('key_release_event', self.onkeyrelease)153canvas.mpl_connect('button_press_event', self.onclick)154toolbar = NavigationToolbar2TkAgg(canvas, root)155toolbar.update()156self.shift_down = False157self.controllbar = ControllBar(root, controller)158self.f = f159self.ax = ax160self.canvas = canvas161self.controller = controller162self.contours = []163self.c_labels = None164self.plot_kernels()165166def plot_kernels(self):167self.ax.text(-50, -60, "Linear: $u^T v$")168self.ax.text(-20, -60, "RBF: $\exp (-\gamma \| u-v \|^2)$")169self.ax.text(10, -60, "Poly: $(\gamma \, u^T v + r)^d$")170171def onkeypress(self, event):172if event.key == "shift":173self.shift_down = True174175def onkeyrelease(self, event):176if event.key == "shift":177self.shift_down = False178179def onclick(self, event):180if event.xdata and event.ydata:181if self.shift_down or event.button == 3:182self.controller.add_example(event.xdata, event.ydata, -1)183elif event.button == 1:184self.controller.add_example(event.xdata, event.ydata, 1)185186def update_example(self, model, idx):187x, y, l = model.data[idx]188if l == 1:189color = 'w'190elif l == -1:191color = 'k'192self.ax.plot([x], [y], "%so" % color, scalex=0.0, scaley=0.0)193194def update(self, event, model):195if event == "examples_loaded":196for i in xrange(len(model.data)):197self.update_example(model, i)198199if event == "example_added":200self.update_example(model, -1)201202if event == "clear":203self.ax.clear()204self.ax.set_xticks([])205self.ax.set_yticks([])206self.contours = []207self.c_labels = None208self.plot_kernels()209210if event == "surface":211self.remove_surface()212self.plot_support_vectors(model.clf.support_vectors_)213self.plot_decision_surface(model.surface, model.surface_type)214215self.canvas.draw()216217def remove_surface(self):218"""Remove old decision surface."""219if len(self.contours) > 0:220for contour in self.contours:221if isinstance(contour, ContourSet):222for lineset in contour.collections:223lineset.remove()224else:225contour.remove()226self.contours = []227228def plot_support_vectors(self, support_vectors):229"""Plot the support vectors by placing circles over the230corresponding data points and adds the circle collection231to the contours list."""232cs = self.ax.scatter(support_vectors[:, 0], support_vectors[:, 1],233s=80, edgecolors="k", facecolors="none")234self.contours.append(cs)235236def plot_decision_surface(self, surface, type):237X1, X2, Z = surface238if type == 0:239levels = [-1.0, 0.0, 1.0]240linestyles = ['dashed', 'solid', 'dashed']241colors = 'k'242self.contours.append(self.ax.contour(X1, X2, Z, levels,243colors=colors,244linestyles=linestyles))245elif type == 1:246self.contours.append(self.ax.contourf(X1, X2, Z, 10,247cmap=matplotlib.cm.bone,248origin='lower', alpha=0.85))249self.contours.append(self.ax.contour(X1, X2, Z, [0.0], colors='k',250linestyles=['solid']))251else:252raise ValueError("surface type unknown")253254255class ControllBar(object):256def __init__(self, root, controller):257fm = Tk.Frame(root)258kernel_group = Tk.Frame(fm)259Tk.Radiobutton(kernel_group, text="Linear", variable=controller.kernel,260value=0, command=controller.refit).pack(anchor=Tk.W)261Tk.Radiobutton(kernel_group, text="RBF", variable=controller.kernel,262value=1, command=controller.refit).pack(anchor=Tk.W)263Tk.Radiobutton(kernel_group, text="Poly", variable=controller.kernel,264value=2, command=controller.refit).pack(anchor=Tk.W)265kernel_group.pack(side=Tk.LEFT)266267valbox = Tk.Frame(fm)268controller.complexity = Tk.StringVar()269controller.complexity.set("1.0")270c = Tk.Frame(valbox)271Tk.Label(c, text="C:", anchor="e", width=7).pack(side=Tk.LEFT)272Tk.Entry(c, width=6, textvariable=controller.complexity).pack(273side=Tk.LEFT)274c.pack()275276controller.gamma = Tk.StringVar()277controller.gamma.set("0.01")278g = Tk.Frame(valbox)279Tk.Label(g, text="gamma:", anchor="e", width=7).pack(side=Tk.LEFT)280Tk.Entry(g, width=6, textvariable=controller.gamma).pack(side=Tk.LEFT)281g.pack()282283controller.degree = Tk.StringVar()284controller.degree.set("3")285d = Tk.Frame(valbox)286Tk.Label(d, text="degree:", anchor="e", width=7).pack(side=Tk.LEFT)287Tk.Entry(d, width=6, textvariable=controller.degree).pack(side=Tk.LEFT)288d.pack()289290controller.coef0 = Tk.StringVar()291controller.coef0.set("0")292r = Tk.Frame(valbox)293Tk.Label(r, text="coef0:", anchor="e", width=7).pack(side=Tk.LEFT)294Tk.Entry(r, width=6, textvariable=controller.coef0).pack(side=Tk.LEFT)295r.pack()296valbox.pack(side=Tk.LEFT)297298cmap_group = Tk.Frame(fm)299Tk.Radiobutton(cmap_group, text="Hyperplanes",300variable=controller.surface_type, value=0,301command=controller.refit).pack(anchor=Tk.W)302Tk.Radiobutton(cmap_group, text="Surface",303variable=controller.surface_type, value=1,304command=controller.refit).pack(anchor=Tk.W)305306cmap_group.pack(side=Tk.LEFT)307308train_button = Tk.Button(fm, text='Fit', width=5,309command=controller.fit)310train_button.pack()311fm.pack(side=Tk.LEFT)312Tk.Button(fm, text='Clear', width=5,313command=controller.clear_data).pack(side=Tk.LEFT)314315316def get_parser():317from optparse import OptionParser318op = OptionParser()319op.add_option("--output",320action="store", type="str", dest="output",321help="Path where to dump data.")322return op323324325def main(argv):326op = get_parser()327opts, args = op.parse_args(argv[1:])328root = Tk.Tk()329model = Model()330controller = Controller(model)331root.wm_title("Scikit-learn Libsvm GUI")332view = View(root, controller)333model.add_observer(view)334Tk.mainloop()335336if opts.output:337model.dump_svmlight_file(opts.output)338339if __name__ == "__main__":340main(sys.argv)341342343