📚 The CoCalc Library - books, templates and other resources
License: OTHER
import re
# sklearn.externals.six was removed in scikit-learn 0.23; io.StringIO is the
# standard-library equivalent it used to re-export on Python 3.
from io import StringIO

import numpy as np
import matplotlib.pyplot as plt
from imageio import imread
from scipy import ndimage
from sklearn.datasets import make_moons
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz

from .tools import discrete_scatter
from .plot_helpers import cm2


def tree_image(tree, fout=None):
    """Render the top levels of a fitted decision tree as an image array.

    The tree is exported to Graphviz DOT (at most 3 levels, without impurity).
    The ``samples = N`` entries are then stripped and ``value`` is relabelled
    as ``counts``, and the result is rendered to PNG and read back.

    Parameters
    ----------
    tree : fitted scikit-learn tree estimator
        Passed directly to ``sklearn.tree.export_graphviz``.
    fout : str, optional
        Base path handed to ``graphviz.Source.render``; the image is read
        back from ``fout + ".png"``. Defaults to ``"tmp"`` (relative to the
        current working directory).

    Returns
    -------
    ndarray
        The rendered PNG as returned by ``imageio.imread``. If the
        ``graphviz`` Python package cannot be imported, a 10x10 array of
        ones with a single zero at ``[0, 0]`` is returned instead, so
        callers can still ``imshow`` something.
    """
    try:
        import graphviz
    except ImportError:
        # Hacky near-white placeholder; presumably the single 0 keeps the
        # imshow colour range from being degenerate.
        placeholder = np.ones((10, 10))
        placeholder[0, 0] = 0
        return placeholder

    dot_buffer = StringIO()
    export_graphviz(tree, out_file=dot_buffer, max_depth=3, impurity=False)
    dot_source = dot_buffer.getvalue()
    # The patterns' "\\n" matches a literal backslash-n (the separator inside
    # DOT node labels). Remove the "samples" field whether it is followed by
    # a separator or preceded by one.
    dot_source = re.sub(r"samples = [0-9]+\\n", "", dot_source)
    dot_source = re.sub(r"\\nsamples = [0-9]+", "", dot_source)
    dot_source = dot_source.replace("value", "counts")

    graph = graphviz.Source(dot_source, format="png")
    if fout is None:
        fout = "tmp"
    graph.render(fout)
    return imread(fout + ".png")


def plot_tree_progressive():
    """Show how a decision tree partitions the two-moons data as depth grows.

    First draws a scatter plot of
    ``make_moons(n_samples=100, noise=0.25, random_state=3)``.
    It then creates one 1x2 figure per depth in (1, 2, 9). Each figure shows
    the tree's decision regions on the left and the rendered tree on the
    right.
    """
    X, y = make_moons(n_samples=100, noise=0.25, random_state=3)
    plt.figure()
    ax = plt.gca()
    discrete_scatter(X[:, 0], X[:, 1], y, ax=ax)
    ax.set_xlabel("Feature 0")
    ax.set_ylabel("Feature 1")
    plt.legend(["Class 0", "Class 1"], loc='best')

    depths = [1, 2, 9]
    # Create every figure up front (one per depth) and keep only the pair of
    # Axes from each plt.subplots call; deriving the count from `depths`
    # keeps the two in sync.
    axes_rows = [
        plt.subplots(1, 2, figsize=(12, 4),
                     subplot_kw={'xticks': (), 'yticks': ()})[1]
        for _ in depths
    ]

    for (partition_ax, image_ax), max_depth in zip(axes_rows, depths):
        tree = plot_tree(X, y, max_depth=max_depth, ax=partition_ax)
        image_ax.imshow(tree_image(tree))
        image_ax.set_axis_off()


def plot_tree_partition(X, y, tree, ax=None):
    """Plot a fitted tree's decision regions and leaf boundaries over (X, y).

    Parameters
    ----------
    X : ndarray
        Samples; only the first two columns are used, as the plot axes.
        Assumed to be exactly 2 features, since ``tree`` is evaluated on a
        2-column grid — TODO confirm for callers.
    y : array-like
        Labels passed to ``discrete_scatter`` to colour the points.
    tree : fitted classifier
        Must provide ``predict`` and ``apply`` (e.g. a fitted
        ``DecisionTreeClassifier``).
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.

    Returns
    -------
    matplotlib Axes
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.gca()
    # Pad the plotting window by half the overall standard deviation of X.
    eps = X.std() / 2.

    x_min, x_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    y_min, y_max = X[:, 1].min() - eps, X[:, 1].max() + eps
    xx = np.linspace(x_min, x_max, 1000)
    yy = np.linspace(y_min, y_max, 1000)

    X1, X2 = np.meshgrid(xx, yy)
    X_grid = np.c_[X1.ravel(), X2.ravel()]

    Z = tree.predict(X_grid).reshape(X1.shape)
    # Leaf index for every grid point. A nonzero Laplacian means some
    # neighbouring grid point lies in a different leaf, so these points
    # trace the leaf boundaries.
    faces = tree.apply(X_grid).reshape(X1.shape)
    border = ndimage.laplace(faces) != 0
    # levels=[0, .5, 1] presumes binary 0/1 predictions — TODO confirm for
    # other label encodings.
    ax.contourf(X1, X2, Z, alpha=.4, cmap=cm2, levels=[0, .5, 1])
    ax.scatter(X1[border], X2[border], marker='.', s=1)

    discrete_scatter(X[:, 0], X[:, 1], y, ax=ax)
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks(())
    ax.set_yticks(())
    return ax


def plot_tree(X, y, max_depth=1, ax=None):
    """Fit a depth-limited decision tree on (X, y) and plot its partition.

    Parameters
    ----------
    X, y : array-like
        Training data, forwarded to ``DecisionTreeClassifier.fit`` and
        ``plot_tree_partition``.
    max_depth : int or None, default 1
        Forwarded to ``DecisionTreeClassifier``; also shown in the title.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current axes.

    Returns
    -------
    DecisionTreeClassifier
        The fitted tree (``random_state=0``).
    """
    tree = DecisionTreeClassifier(max_depth=max_depth, random_state=0).fit(X, y)
    ax = plot_tree_partition(X, y, tree, ax=ax)
    # Generic formatting rather than "%d", so max_depth=None (accepted by
    # DecisionTreeClassifier) produces a title instead of a TypeError.
    ax.set_title("depth = {}".format(max_depth))
    return tree