Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
Download

📚 The CoCalc Library - books, templates and other resources

132923 views
License: OTHER
1
import numpy as np
2
import matplotlib.pyplot as plt
3
from sklearn.svm import SVC
4
from sklearn.model_selection import GridSearchCV, train_test_split
5
from sklearn.datasets import load_iris
6
import pandas as pd
7
8
9
def plot_cross_val_selection():
    """Plot cross-validation accuracies for an SVC grid search on iris.

    The iris data is split with ``train_test_split(random_state=0)`` and a
    ``GridSearchCV`` over ``C`` and ``gamma`` (six values each) is fitted on
    the training portion using 5-fold cross-validation.  Only the parameter
    settings from row 15 of ``cv_results_`` onwards are plotted.

    For every plotted setting the figure shows:

    * the accuracy of each individual fold (gray upward triangles),
    * the mean accuracy over the folds (hollow downward triangle),
    * a red circle around the setting with the highest mean accuracy among
      the plotted settings.

    The function creates a new pyplot figure and returns nothing.
    """
    n_folds = 5

    iris = load_iris()
    # The held-out test portion is not needed for this plot.
    X_trainval, _, y_trainval, _ = train_test_split(iris.data,
                                                    iris.target,
                                                    random_state=0)

    param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100],
                  'gamma': [0.001, 0.01, 0.1, 1, 10, 100]}
    grid_search = GridSearchCV(SVC(), param_grid, cv=n_folds,
                               return_train_score=True)
    grid_search.fit(X_trainval, y_trainval)
    # Skip the first 15 settings so the figure stays readable.
    results = pd.DataFrame(grid_search.cv_results_)[15:]

    # Position (within the sliced frame) of the best mean validation score.
    best = np.argmax(results.mean_test_score.values)
    plt.figure(figsize=(10, 3))
    plt.xlim(-1, len(results))
    plt.ylim(0, 1.1)

    # Per-fold score columns; built once instead of on every iteration.
    split_columns = ['split%d_test_score' % fold for fold in range(n_folds)]
    for i, (_, row) in enumerate(results.iterrows()):
        scores = row[split_columns]
        marker_cv, = plt.plot([i] * n_folds, scores, '^', c='gray',
                              markersize=5, alpha=.5)
        marker_mean, = plt.plot(i, row.mean_test_score, 'v', c='none',
                                alpha=1, markersize=10, markeredgecolor='k')
        if i == best:
            marker_best, = plt.plot(i, row.mean_test_score, 'o', c='red',
                                    fillstyle="none", alpha=1, markersize=20,
                                    markeredgewidth=3)

    # Labels must come from the *sliced* results so that there is exactly one
    # label per tick and each label describes the setting plotted above it.
    tick_labels = [str(params).strip("{}").replace("'", "")
                   for params in results['params']]
    plt.xticks(range(len(results)), tick_labels, rotation=90)
    plt.ylabel("Validation accuracy")
    plt.xlabel("Parameter settings")
    plt.legend([marker_cv, marker_mean, marker_best],
               ["cv accuracy", "mean accuracy", "best parameter setting"],
               loc=(1.05, .4))
45
46
47
def plot_grid_search_overview():
    """Draw a box-and-arrow overview of the grid-search workflow.

    A new pyplot figure is created with the axes, ticks and frame hidden.
    Labelled rounded boxes ("data set", "training data", "test data",
    "parameter grid", "cross-validation", "best parameters",
    "retrained model", "final evaluation") are placed at fixed pixel
    positions and connected by arrows.  Nothing is returned.
    """
    plt.figure(figsize=(10, 3), dpi=70)
    canvas = plt.gca()
    # Only the annotations should be visible: hide frame and both axes.
    canvas.set_frame_on(False)
    canvas.xaxis.set_visible(False)
    canvas.yaxis.set_visible(False)

    def add_box(ax, label, origin, points_to=None):
        """Place a text box at ``origin`` with an arrow to ``points_to``.

        ``origin`` is in axes pixels.  When ``points_to`` (an annotation
        returned by an earlier call) is omitted, the arrow target is the
        box's own position.  Returns the created annotation.
        """
        if points_to is None:
            tip_patch = None
            tip = origin
        else:
            tip_patch = points_to.get_bbox_patch()
            tip = points_to.get_position()

        arrow_style = {
            "arrowstyle": "-|>",
            "fc": "w",
            "ec": "k",
            "patchB": tip_patch,
            "connectionstyle": "arc3,rad=0.0",
        }
        box_style = {"boxstyle": "round", "fc": "w"}
        box = ax.annotate(label, xy=tip, xytext=origin,
                          xycoords='axes pixels',
                          textcoords='axes pixels',
                          size=20,
                          arrowprops=arrow_style,
                          bbox=box_style,
                          horizontalalignment="center",
                          verticalalignment="center")
        plt.draw()
        return box

    # Layout grid, in axes pixels.
    unit = 100
    top = 400
    data_row = top - unit
    cv_row = top - 2 * unit
    bottom_row = top - 3 * unit

    # Bottom row is built right-to-left so each box can point at the next.
    final_evaluation = add_box(canvas, "final evaluation",
                               (5 * unit, bottom_row))
    retrained_model = add_box(canvas, "retrained model",
                              (3 * unit, bottom_row), final_evaluation)
    best_parameters = add_box(canvas, "best parameters",
                              (.5 * unit, bottom_row), retrained_model)
    cross_validation = add_box(canvas, "cross-validation",
                               (.5 * unit, cv_row), best_parameters)
    add_box(canvas, "parameter grid", (0.0, top), cross_validation)

    # "training data" is drawn twice at the same spot to get two arrows.
    training_data = add_box(canvas, "training data",
                            (2 * unit, data_row), cross_validation)
    add_box(canvas, "training data", (2 * unit, data_row), retrained_model)
    test_data = add_box(canvas, "test data",
                        (5 * unit, data_row), final_evaluation)

    # Likewise "data set" is drawn once per outgoing arrow.
    for destination in (training_data, test_data):
        add_box(canvas, "data set", (3.5 * unit, top - 0.0), destination)

    plt.ylim(0, 1)
    plt.xlim(0, 1.5)
94
95