Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download

📚 The CoCalc Library - books, templates and other resources

132923 views
License: OTHER
1
import numpy as np
2
import matplotlib.pyplot as plt
3
4
5
def plot_group_kfold():
    """Visualize GroupKFold on 12 samples in 4 groups.

    Draws one column per sample over 3 CV iterations: grey = test,
    hatched white = training, plain white = not selected in that
    iteration.  A bottom row labels every sample with its group id,
    showing that all samples of a group land in the same test fold.
    """
    from sklearn.model_selection import GroupKFold
    # group membership of each of the 12 samples
    groups = [0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3]

    plt.figure(figsize=(10, 2))
    plt.title("GroupKFold")

    axes = plt.gca()
    axes.set_frame_on(False)

    n_folds = 12            # one bar column per sample in this figure
    n_samples = 12
    n_iter = 3              # number of CV splits drawn
    n_samples_per_fold = 1

    cv = GroupKFold(n_splits=3)
    # mask[i, j]: 0 = unused, 1 = training, 2 = test in iteration i
    mask = np.zeros((n_iter, n_samples))
    for i, (train, test) in enumerate(cv.split(range(n_samples),
                                               groups=groups)):
        mask[i, train] = 1
        mask[i, test] = 2

    for i in range(n_folds):
        # test is grey
        colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
        boxes = axes.barh(y=range(n_iter), width=[1 - 0.1] * n_iter,
                          left=i * n_samples_per_fold, height=.6,
                          color=colors, hatch="//", edgecolor="k",
                          align='edge')
        # not selected has no hatch
        for j in np.where(mask[:, i] == 0)[0]:
            boxes[j].set_hatch("")

    # bottom row of empty boxes that will hold the group labels
    axes.barh(y=[n_iter] * n_folds, width=[1 - 0.1] * n_folds,
              left=np.arange(n_folds) * n_samples_per_fold, height=.6,
              color="w", edgecolor='k', align="edge")

    for i in range(n_samples):
        axes.text((i + .5) * n_samples_per_fold, 3.5, "%d" % groups[i],
                  horizontalalignment="center")

    axes.invert_yaxis()
    axes.set_xlim(0, n_samples + 1)
    axes.set_ylabel("CV iterations")
    axes.set_xlabel("Data points")
    axes.set_xticks(np.arange(n_samples) + .5)
    axes.set_xticklabels(np.arange(1, n_samples + 1))
    axes.set_yticks(np.arange(n_iter + 1) + .3)
    axes.set_yticklabels(
        ["Split %d" % x for x in range(1, n_iter + 1)] + ["Group"])
    # `boxes` comes from the last column; indices 0/1 happen to be a
    # training and a test bar there
    plt.legend([boxes[0], boxes[1]], ["Training set", "Test set"],
               loc=(1, .3))
    plt.tight_layout()
def plot_shuffle_split():
    """Visualize ShuffleSplit with 10 points, train_size=5, test_size=2
    and n_splits=4: grey = test, hatched white = training, plain white =
    not selected in that split.
    """
    from sklearn.model_selection import ShuffleSplit
    plt.figure(figsize=(10, 2))
    plt.title("ShuffleSplit with 10 points"
              ", train_size=5, test_size=2, n_splits=4")

    axes = plt.gca()
    axes.set_frame_on(False)

    n_folds = 10            # one bar column per sample in this figure
    n_samples = 10
    n_iter = 4              # number of CV splits drawn
    n_samples_per_fold = 1

    # fixed random_state so the legend indices below stay valid
    ss = ShuffleSplit(n_splits=4, train_size=5, test_size=2, random_state=43)
    # mask[i, j]: 0 = not selected, 1 = training, 2 = test in split i
    mask = np.zeros((n_iter, n_samples))
    for i, (train, test) in enumerate(ss.split(range(n_samples))):
        mask[i, train] = 1
        mask[i, test] = 2

    for i in range(n_folds):
        # test is grey
        colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
        boxes = axes.barh(y=range(n_iter), width=[1 - 0.1] * n_iter,
                          left=i * n_samples_per_fold, height=.6,
                          color=colors, hatch="//", edgecolor='k',
                          align='edge')
        # not selected has no hatch
        for j in np.where(mask[:, i] == 0)[0]:
            boxes[j].set_hatch("")

    axes.invert_yaxis()
    axes.set_xlim(0, n_samples + 1)
    axes.set_ylabel("CV iterations")
    axes.set_xlabel("Data points")
    axes.set_xticks(np.arange(n_samples) + .5)
    axes.set_xticklabels(np.arange(1, n_samples + 1))
    axes.set_yticks(np.arange(n_iter) + .3)
    axes.set_yticklabels(["Split %d" % x for x in range(1, n_iter + 1)])
    # legend hacked for this random state
    plt.legend([boxes[1], boxes[0], boxes[2]],
               ["Training set", "Test set", "Not selected"], loc=(1, .3))
    plt.tight_layout()
def plot_stratified_cross_validation():
    """Contrast standard 3-fold CV on class-sorted data (top panel) with
    stratified 3-fold CV (bottom panel) for 150 samples in three
    equally-sized classes.
    """
    fig, (top, bottom) = plt.subplots(2, 1, figsize=(12, 5))

    # --- top panel: plain k-fold on data sorted by class label ---
    top.set_title("Standard cross-validation with sorted class labels")
    top.set_frame_on(False)

    n_folds = 3
    n_samples = 150
    n_samples_per_fold = n_samples / float(n_folds)

    for split in range(n_folds):
        # fold `split` is the (grey) test fold of this split
        face = ["w"] * n_folds
        face[split] = "grey"
        top.barh(y=range(n_folds),
                 width=[n_samples_per_fold - 1] * n_folds,
                 left=split * n_samples_per_fold, height=.6,
                 color=face, hatch="//", edgecolor='k', align='edge')

    # bottom row of the top panel: boxes holding the class labels
    top.barh(y=[n_folds] * n_folds,
             width=[n_samples_per_fold - 1] * n_folds,
             left=np.arange(3) * n_samples_per_fold, height=.6,
             color="w", edgecolor='k', align='edge')

    top.invert_yaxis()
    top.set_xlim(0, n_samples + 1)
    top.set_ylabel("CV iterations")
    top.set_xlabel("Data points")
    top.set_xticks(np.arange(n_samples_per_fold / 2., n_samples,
                             n_samples_per_fold))
    top.set_xticklabels(["Fold %d" % x for x in range(1, n_folds + 1)])
    top.set_yticks(np.arange(n_folds + 1) + .3)
    top.set_yticklabels(
        ["Split %d" % x for x in range(1, n_folds + 1)] + ["Class label"])
    for cls in range(3):
        top.text((cls + .5) * n_samples_per_fold, 3.5, "Class %d" % cls,
                 horizontalalignment="center")

    # --- bottom panel: stratified cross-validation ---
    bottom.set_title("Stratified Cross-validation")
    bottom.set_frame_on(False)
    bottom.invert_yaxis()
    bottom.set_xlim(0, n_samples + 1)
    bottom.set_ylabel("CV iterations")
    bottom.set_xlabel("Data points")

    bottom.set_yticks(np.arange(n_folds + 1) + .3)
    bottom.set_yticklabels(
        ["Split %d" % x for x in range(1, n_folds + 1)] + ["Class label"])

    # every class block is cut into three equal sub-slices
    n_subsplit = n_samples_per_fold / 3.
    for split in range(n_folds):
        # grey test stripes: the split-th sub-slice of each class block
        test_bars = bottom.barh(
            y=[split] * n_folds, width=[n_subsplit - 1] * n_folds,
            left=np.arange(n_folds) * n_samples_per_fold
            + split * n_subsplit,
            height=.6, color="grey", hatch="//", edgecolor='k',
            align='edge')

    w = 2 * n_subsplit - 1
    # hatched white training stripes, hand-placed for each split row
    bottom.barh(y=[0] * n_folds, width=[w] * n_folds,
                left=np.arange(n_folds) * n_samples_per_fold
                + (0 + 1) * n_subsplit,
                height=.6, color="w", hatch="//", edgecolor='k',
                align='edge')
    bottom.barh(y=[1] * (n_folds + 1), width=[w / 2., w, w, w / 2.],
                left=np.maximum(0, np.arange(n_folds + 1)
                                * n_samples_per_fold - n_subsplit),
                height=.6, color="w", hatch="//", edgecolor='k',
                align='edge')
    training_bars = bottom.barh(y=[2] * n_folds, width=[w] * n_folds,
                                left=np.arange(n_folds)
                                * n_samples_per_fold,
                                height=.6, color="w", hatch="//",
                                edgecolor='k', align='edge')

    # class-label boxes for the bottom panel
    bottom.barh(y=[n_folds] * n_folds,
                width=[n_samples_per_fold - 1] * n_folds,
                left=np.arange(n_folds) * n_samples_per_fold, height=.6,
                color="w", edgecolor='k', align='edge')

    for cls in range(3):
        bottom.text((cls + .5) * n_samples_per_fold, 3.5, "Class %d" % cls,
                    horizontalalignment="center")
    bottom.set_ylim(4, -0.1)
    plt.legend([training_bars[0], test_bars[0]],
               ['Training data', 'Test data'], loc=(1.05, 1), frameon=False)

    fig.tight_layout()
def plot_cross_validation():
    """Visualize standard 5-fold cross-validation on 25 data points.

    For each of the 5 splits one fold is drawn grey (test set) and the
    remaining folds hatched white (training set).
    """
    plt.figure(figsize=(12, 2))
    plt.title("cross_validation")
    axes = plt.gca()
    axes.set_frame_on(False)

    n_folds = 5
    n_samples = 25

    n_samples_per_fold = n_samples / float(n_folds)

    for i in range(n_folds):
        # fold i is the (grey) test fold of split i
        colors = ["w"] * n_folds
        colors[i] = "grey"
        bars = plt.barh(
            y=range(n_folds), width=[n_samples_per_fold - 0.1] * n_folds,
            left=i * n_samples_per_fold, height=.6, color=colors,
            hatch="//", edgecolor='k', align='edge')
    axes.invert_yaxis()
    axes.set_xlim(0, n_samples + 1)
    plt.ylabel("CV iterations")
    plt.xlabel("Data points")
    plt.xticks(np.arange(n_samples_per_fold / 2., n_samples,
                         n_samples_per_fold),
               ["Fold %d" % x for x in range(1, n_folds + 1)])
    plt.yticks(np.arange(n_folds) + .3,
               ["Split %d" % x for x in range(1, n_folds + 1)])
    # `bars` is from the last split, where only the last fold is grey,
    # so index n_folds - 1 (not a hard-coded 4) is the test sample
    plt.legend([bars[0], bars[n_folds - 1]],
               ['Training data', 'Test data'],
               loc=(1.05, 0.4), frameon=False)
def plot_threefold_split():
    """Draw the three-way split of a dataset into training, validation
    and test sets, annotating what each part is used for.
    """
    plt.figure(figsize=(15, 1))
    ax = plt.gca()
    bars = ax.barh([0, 0, 0], [11.9, 2.9, 4.9], left=[0, 12, 15],
                   color=['white', 'grey', 'grey'], hatch="//",
                   edgecolor='k', align='edge')
    # the test-set bar stays grey but unhatched
    bars[2].set_hatch("")
    ax.set_yticks(())
    ax.set_frame_on(False)
    ax.set_ylim(-.1, .8)
    ax.set_xlim(-0.1, 20.1)
    ax.set_xticks([6, 13.3, 17.5])
    ax.set_xticklabels(["training set", "validation set", "test set"],
                       fontdict={'fontsize': 20})
    ax.tick_params(length=0, labeltop=True, labelbottom=False)
    # purpose annotations below each segment, centred on its tick
    for x, caption in [(6, "Model fitting"),
                       (13.3, "Parameter selection"),
                       (17.5, "Evaluation")]:
        ax.text(x, -.3, caption, fontdict={'fontsize': 13},
                horizontalalignment="center")