Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download

📚 The CoCalc Library - books, templates and other resources

132928 views
License: OTHER
1
import numpy as np
2
import matplotlib.pyplot as plt
3
4
from sklearn.tree import DecisionTreeClassifier
5
6
from sklearn.externals.six import StringIO # doctest: +SKIP
7
from sklearn.tree import export_graphviz
8
from imageio import imread
9
from scipy import ndimage
10
from sklearn.datasets import make_moons
11
12
import re
13
14
from .tools import discrete_scatter
15
from .plot_helpers import cm2
16
17
18
def tree_image(tree, fout=None):
19
try:
20
import graphviz
21
except ImportError:
22
# make a hacky white plot
23
x = np.ones((10, 10))
24
x[0, 0] = 0
25
return x
26
dot_data = StringIO()
27
export_graphviz(tree, out_file=dot_data, max_depth=3, impurity=False)
28
data = dot_data.getvalue()
29
data = re.sub(r"samples = [0-9]+\\n", "", data)
30
data = re.sub(r"\\nsamples = [0-9]+", "", data)
31
data = re.sub(r"value", "counts", data)
32
33
graph = graphviz.Source(data, format="png")
34
if fout is None:
35
fout = "tmp"
36
graph.render(fout)
37
return imread(fout + ".png")
38
39
40
def plot_tree_progressive():
    """Show decision trees of growing depth on the two-moons toy data.

    Draws one figure with the raw data, then three further figures (for
    ``max_depth`` 1, 2 and 9). Each of those has the partition produced by
    ``plot_tree`` on the left and the rendered tree image on the right.
    """
    X, y = make_moons(n_samples=100, noise=0.25, random_state=3)

    # Overview figure: the data set on its own.
    plt.figure()
    data_ax = plt.gca()
    discrete_scatter(X[:, 0], X[:, 1], y, ax=data_ax)
    data_ax.set_xlabel("Feature 0")
    data_ax.set_ylabel("Feature 1")
    plt.legend(["Class 0", "Class 1"], loc='best')

    # One separate two-panel figure per depth; ticks are suppressed.
    panel_rows = []
    for _ in range(3):
        _, row = plt.subplots(1, 2, figsize=(12, 4),
                              subplot_kw={'xticks': (), 'yticks': ()})
        panel_rows.append(row)
    panel_grid = np.array(panel_rows)

    for row, max_depth in zip(panel_grid, [1, 2, 9]):
        partition_ax, image_ax = row
        fitted = plot_tree(X, y, max_depth=max_depth, ax=partition_ax)
        image_ax.imshow(tree_image(fitted))
        image_ax.set_axis_off()
60
61
62
def plot_tree_partition(X, y, tree, ax=None):
    """Draw the regions a fitted tree assigns to each class over 2-D data.

    Parameters
    ----------
    X : array
        Data whose columns 0 and 1 are used as the plot's x and y axes.
    y : array
        Labels forwarded to ``discrete_scatter``.
    tree : fitted estimator
        Must provide ``predict`` and ``apply``.
    ax : matplotlib axes or None
        Target axes; the current axes are used when None.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        ax = plt.gca()

    # Pad the plotted area by half the overall standard deviation of X.
    margin = X.std() / 2.
    first, second = X[:, 0], X[:, 1]
    x_min = first.min() - margin
    x_max = first.max() + margin
    y_min = second.min() - margin
    y_max = second.max() + margin

    # Evaluate the tree on a dense 1000 x 1000 grid covering that area.
    grid_x, grid_y = np.meshgrid(np.linspace(x_min, x_max, 1000),
                                 np.linspace(y_min, y_max, 1000))
    grid_points = np.c_[grid_x.ravel(), grid_y.ravel()]

    predicted = tree.predict(grid_points).reshape(grid_x.shape)
    regions = tree.apply(grid_points).reshape(grid_x.shape)
    # A nonzero Laplacian means the value returned by apply() changes
    # around that grid cell; those cells are drawn as the region borders.
    on_border = ndimage.laplace(regions) != 0

    ax.contourf(grid_x, grid_y, predicted, alpha=.4, cmap=cm2,
                levels=[0, .5, 1])
    ax.scatter(grid_x[on_border], grid_y[on_border], marker='.', s=1)

    # Original samples go on top of the shaded regions.
    discrete_scatter(first, second, y, ax=ax)
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks(())
    ax.set_yticks(())
    return ax
89
90
91
def plot_tree(X, y, max_depth=1, ax=None):
    """Fit a depth-limited decision tree on (X, y) and plot its partition.

    The classifier is created with ``random_state=0``. The partition is
    drawn by ``plot_tree_partition`` on ``ax`` and the axes are titled
    with the depth. Returns the fitted classifier.
    """
    classifier = DecisionTreeClassifier(max_depth=max_depth, random_state=0)
    classifier.fit(X, y)
    drawn_ax = plot_tree_partition(X, y, classifier, ax=ax)
    drawn_ax.set_title("depth = %d" % max_depth)
    return classifier
96
97