📚 The CoCalc Library - books, templates and other resources
License: OTHER
import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.datasets import load_iris
import pandas as pd


def plot_cross_val_selection():
    """Plot per-fold and mean validation accuracy for part of an SVC grid search.

    Runs a 5-fold ``GridSearchCV`` over ``C`` and ``gamma`` for an ``SVC`` on
    the train/validation portion of the iris data set. For each parameter
    setting from the 16th one onward it plots the individual fold scores
    (gray triangles) and their mean (black outlined triangle). The setting
    with the highest mean score among those shown is circled in red.
    """
    n_folds = 5
    # Only the later part of the 6 x 6 grid is plotted so the rotated tick
    # labels stay readable; the first settings are dropped from the figure.
    first_shown = 15

    iris = load_iris()
    X_trainval, _, y_trainval, _ = train_test_split(iris.data, iris.target,
                                                    random_state=0)

    param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100],
                  'gamma': [0.001, 0.01, 0.1, 1, 10, 100]}
    grid_search = GridSearchCV(SVC(), param_grid, cv=n_folds,
                               return_train_score=True)
    grid_search.fit(X_trainval, y_trainval)
    results = pd.DataFrame(grid_search.cv_results_).iloc[first_shown:]

    # Names of the columns holding the individual fold scores (loop-invariant).
    split_columns = ['split%d_test_score' % fold for fold in range(n_folds)]

    # ``best`` is a positional index into ``results``, as is ``i`` below.
    best = np.argmax(results.mean_test_score.values)
    plt.figure(figsize=(10, 3))
    plt.xlim(-1, len(results))
    plt.ylim(0, 1.1)
    for i, (_, row) in enumerate(results.iterrows()):
        scores = row[split_columns]
        marker_cv, = plt.plot([i] * n_folds, scores, '^', c='gray',
                              markersize=5, alpha=.5)
        marker_mean, = plt.plot(i, row.mean_test_score, 'v', c='none',
                                alpha=1, markersize=10, markeredgecolor='k')
        if i == best:
            marker_best, = plt.plot(i, row.mean_test_score, 'o', c='red',
                                    fillstyle="none", alpha=1, markersize=20,
                                    markeredgewidth=3)

    # Label the ticks with the parameters of the rows actually plotted. Using
    # the full ``cv_results_['params']`` list here would mislabel every tick
    # (and newer matplotlib rejects the label/tick count mismatch).
    plt.xticks(range(len(results)),
               [str(x).strip("{}").replace("'", "") for x in results['params']],
               rotation=90)
    plt.ylabel("Validation accuracy")
    plt.xlabel("Parameter settings")
    plt.legend([marker_cv, marker_mean, marker_best],
               ["cv accuracy", "mean accuracy", "best parameter setting"],
               loc=(1.05, .4))


def plot_grid_search_overview():
    """Draw a box-and-arrow diagram of grid search with a held-out test set.

    The diagram shows these flows:

    * the data set splits into training data and test data;
    * the parameter grid and the training data feed cross-validation;
    * cross-validation yields the best parameters;
    * the best parameters and the training data give a retrained model;
    * the retrained model and the test data lead to the final evaluation.
    """
    plt.figure(figsize=(10, 3), dpi=70)
    axes = plt.gca()
    axes.yaxis.set_visible(False)
    axes.xaxis.set_visible(False)
    axes.set_frame_on(False)

    def draw(ax, text, start, target=None):
        """Place a rounded text box at ``start`` (in axes pixels).

        If ``target`` (an annotation returned by a previous call) is given,
        an arrow is drawn from this box to the target's position. The arrow
        is clipped at the target's box via ``patchB``. Returns the new
        annotation so boxes drawn afterwards can point at it.
        """
        if target is not None:
            patchB = target.get_bbox_patch()
            end = target.get_position()
        else:
            # No target: the arrow has zero length, so only the box shows.
            end = start
            patchB = None
        annotation = ax.annotate(text, xy=end, xytext=start,
                                 xycoords='axes pixels',
                                 textcoords='axes pixels', size=20,
                                 arrowprops=dict(
                                     arrowstyle="-|>", fc="w", ec="k",
                                     patchB=patchB,
                                     connectionstyle="arc3,rad=0.0"),
                                 bbox=dict(boxstyle="round", fc="w"),
                                 horizontalalignment="center",
                                 verticalalignment="center")
        plt.draw()
        return annotation

    step = 100  # spacing between rows/columns of boxes, in axes pixels
    top = 400   # y coordinate of the top row of boxes, in axes pixels

    # Boxes are created target-first so each later box can point at them.
    final_evaluation = draw(axes, "final evaluation",
                            (5 * step, top - 3 * step))
    retrained_model = draw(axes, "retrained model",
                           (3 * step, top - 3 * step), final_evaluation)
    best_parameters = draw(axes, "best parameters",
                           (.5 * step, top - 3 * step), retrained_model)
    cross_validation = draw(axes, "cross-validation",
                            (.5 * step, top - 2 * step), best_parameters)
    draw(axes, "parameter grid", (0.0, top - 0), cross_validation)
    training_data = draw(axes, "training data", (2 * step, top - step),
                         cross_validation)
    # Same box drawn again at the same spot to add a second outgoing arrow.
    draw(axes, "training data", (2 * step, top - step), retrained_model)
    test_data = draw(axes, "test data", (5 * step, top - step),
                     final_evaluation)
    # "data set" is likewise drawn twice: one arrow per split.
    draw(axes, "data set", (3.5 * step, top - 0.0), training_data)
    draw(axes, "data set", (3.5 * step, top - 0.0), test_data)
    plt.ylim(0, 1)
    plt.xlim(0, 1.5)