📚 The CoCalc Library - books, templates and other resources
License: OTHER
import numpy as np
import matplotlib.pyplot as plt


def plot_group_kfold():
    from sklearn.model_selection import GroupKFold
    groups = [0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3]

    plt.figure(figsize=(10, 2))
    plt.title("GroupKFold")

    axes = plt.gca()
    axes.set_frame_on(False)

    n_folds = 12
    n_samples = 12
    n_iter = 3
    n_samples_per_fold = 1

    cv = GroupKFold(n_splits=3)
    mask = np.zeros((n_iter, n_samples))
    for i, (train, test) in enumerate(cv.split(range(12), groups=groups)):
        mask[i, train] = 1
        mask[i, test] = 2

    for i in range(n_folds):
        # test is grey
        colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
        # not selected has no hatch

        boxes = axes.barh(y=range(n_iter), width=[1 - 0.1] * n_iter,
                          left=i * n_samples_per_fold, height=.6,
                          color=colors, hatch="//", edgecolor="k",
                          align='edge')
        for j in np.where(mask[:, i] == 0)[0]:
            boxes[j].set_hatch("")

    axes.barh(y=[n_iter] * n_folds, width=[1 - 0.1] * n_folds,
              left=np.arange(n_folds) * n_samples_per_fold, height=.6,
              color="w", edgecolor='k', align="edge")

    for i in range(12):
        axes.text((i + .5) * n_samples_per_fold, 3.5, "%d" % groups[i],
                  horizontalalignment="center")

    axes.invert_yaxis()
    axes.set_xlim(0, n_samples + 1)
    axes.set_ylabel("CV iterations")
    axes.set_xlabel("Data points")
    axes.set_xticks(np.arange(n_samples) + .5)
    axes.set_xticklabels(np.arange(1, n_samples + 1))
    axes.set_yticks(np.arange(n_iter + 1) + .3)
    axes.set_yticklabels(
        ["Split %d" % x for x in range(1, n_iter + 1)] + ["Group"])
    plt.legend([boxes[0], boxes[1]], ["Training set", "Test set"],
               loc=(1, .3))
    plt.tight_layout()


def plot_shuffle_split():
    from sklearn.model_selection import ShuffleSplit
    plt.figure(figsize=(10, 2))
    plt.title("ShuffleSplit with 10 points"
              ", train_size=5, test_size=2, n_splits=4")

    axes = plt.gca()
    axes.set_frame_on(False)

    n_folds = 10
    n_samples = 10
    n_iter = 4
    n_samples_per_fold = 1

    ss = ShuffleSplit(n_splits=4, train_size=5, test_size=2, random_state=43)
    mask = np.zeros((n_iter, n_samples))
    for i, (train, test) in enumerate(ss.split(range(10))):
        mask[i, train] = 1
        mask[i, test] = 2

    for i in range(n_folds):
        # test is grey
        colors = ["grey" if x == 2 else "white" for x in mask[:, i]]
        # not selected has no hatch

        boxes = axes.barh(y=range(n_iter), width=[1 - 0.1] * n_iter,
                          left=i * n_samples_per_fold, height=.6,
                          color=colors, hatch="//", edgecolor='k',
                          align='edge')
        for j in np.where(mask[:, i] == 0)[0]:
            boxes[j].set_hatch("")

    axes.invert_yaxis()
    axes.set_xlim(0, n_samples + 1)
    axes.set_ylabel("CV iterations")
    axes.set_xlabel("Data points")
    axes.set_xticks(np.arange(n_samples) + .5)
    axes.set_xticklabels(np.arange(1, n_samples + 1))
    axes.set_yticks(np.arange(n_iter) + .3)
    axes.set_yticklabels(["Split %d" % x for x in range(1, n_iter + 1)])
    # legend hacked for this random state
    plt.legend([boxes[1], boxes[0], boxes[2]],
               ["Training set", "Test set", "Not selected"], loc=(1, .3))
    plt.tight_layout()


def plot_stratified_cross_validation():
    fig, both_axes = plt.subplots(2, 1, figsize=(12, 5))
    # plt.title("cross_validation_not_stratified")
    axes = both_axes[0]
    axes.set_title("Standard cross-validation with sorted class labels")

    axes.set_frame_on(False)

    n_folds = 3
    n_samples = 150

    n_samples_per_fold = n_samples / float(n_folds)

    for i in range(n_folds):
        colors = ["w"] * n_folds
        colors[i] = "grey"
        axes.barh(y=range(n_folds), width=[n_samples_per_fold - 1] * n_folds,
                  left=i * n_samples_per_fold, height=.6, color=colors,
                  hatch="//", edgecolor='k', align='edge')

    axes.barh(y=[n_folds] * n_folds, width=[n_samples_per_fold - 1] * n_folds,
              left=np.arange(3) * n_samples_per_fold, height=.6,
              color="w", edgecolor='k', align='edge')

    axes.invert_yaxis()
    axes.set_xlim(0, n_samples + 1)
    axes.set_ylabel("CV iterations")
    axes.set_xlabel("Data points")
    axes.set_xticks(np.arange(n_samples_per_fold / 2.,
                              n_samples, n_samples_per_fold))
    axes.set_xticklabels(["Fold %d" % x for x in range(1, n_folds + 1)])
    axes.set_yticks(np.arange(n_folds + 1) + .3)
    axes.set_yticklabels(
        ["Split %d" % x for x in range(1, n_folds + 1)] + ["Class label"])
    for i in range(3):
        axes.text((i + .5) * n_samples_per_fold, 3.5, "Class %d" % i,
                  horizontalalignment="center")

    ax = both_axes[1]
    ax.set_title("Stratified Cross-validation")
    ax.set_frame_on(False)
    ax.invert_yaxis()
    ax.set_xlim(0, n_samples + 1)
    ax.set_ylabel("CV iterations")
    ax.set_xlabel("Data points")

    ax.set_yticks(np.arange(n_folds + 1) + .3)
    ax.set_yticklabels(
        ["Split %d" % x for x in range(1, n_folds + 1)] + ["Class label"])

    n_subsplit = n_samples_per_fold / 3.
    for i in range(n_folds):
        test_bars = ax.barh(
            y=[i] * n_folds, width=[n_subsplit - 1] * n_folds,
            left=np.arange(n_folds) * n_samples_per_fold + i * n_subsplit,
            height=.6, color="grey", hatch="//", edgecolor='k', align='edge')

    w = 2 * n_subsplit - 1
    ax.barh(y=[0] * n_folds, width=[w] * n_folds,
            left=np.arange(n_folds) * n_samples_per_fold + (0 + 1) * n_subsplit,
            height=.6, color="w", hatch="//", edgecolor='k', align='edge')
    ax.barh(y=[1] * (n_folds + 1), width=[w / 2., w, w, w / 2.],
            left=np.maximum(0, np.arange(n_folds + 1) * n_samples_per_fold -
                            n_subsplit),
            height=.6, color="w", hatch="//", edgecolor='k', align='edge')
    training_bars = ax.barh(y=[2] * n_folds, width=[w] * n_folds,
                            left=np.arange(n_folds) * n_samples_per_fold,
                            height=.6, color="w", hatch="//", edgecolor='k',
                            align='edge')

    ax.barh(y=[n_folds] * n_folds, width=[n_samples_per_fold - 1] * n_folds,
            left=np.arange(n_folds) * n_samples_per_fold, height=.6,
            color="w", edgecolor='k', align='edge')

    for i in range(3):
        ax.text((i + .5) * n_samples_per_fold, 3.5, "Class %d" % i,
                horizontalalignment="center")
    ax.set_ylim(4, -0.1)
    plt.legend([training_bars[0], test_bars[0]],
               ['Training data', 'Test data'], loc=(1.05, 1), frameon=False)

    fig.tight_layout()


def plot_cross_validation():
    plt.figure(figsize=(12, 2))
    plt.title("cross_validation")
    axes = plt.gca()
    axes.set_frame_on(False)

    n_folds = 5
    n_samples = 25

    n_samples_per_fold = n_samples / float(n_folds)

    for i in range(n_folds):
        colors = ["w"] * n_folds
        colors[i] = "grey"
        bars = plt.barh(
            y=range(n_folds), width=[n_samples_per_fold - 0.1] * n_folds,
            left=i * n_samples_per_fold, height=.6, color=colors, hatch="//",
            edgecolor='k', align='edge')
    axes.invert_yaxis()
    axes.set_xlim(0, n_samples + 1)
    plt.ylabel("CV iterations")
    plt.xlabel("Data points")
    plt.xticks(np.arange(n_samples_per_fold / 2., n_samples,
                         n_samples_per_fold),
               ["Fold %d" % x for x in range(1, n_folds + 1)])
    plt.yticks(np.arange(n_folds) + .3,
               ["Split %d" % x for x in range(1, n_folds + 1)])
    plt.legend([bars[0], bars[4]], ['Training data', 'Test data'],
               loc=(1.05, 0.4), frameon=False)


def plot_threefold_split():
    plt.figure(figsize=(15, 1))
    axis = plt.gca()
    bars = axis.barh([0, 0, 0], [11.9, 2.9, 4.9], left=[0, 12, 15],
                     color=['white', 'grey', 'grey'], hatch="//",
                     edgecolor='k', align='edge')
    bars[2].set_hatch(r"")
    axis.set_yticks(())
    axis.set_frame_on(False)
    axis.set_ylim(-.1, .8)
    axis.set_xlim(-0.1, 20.1)
    axis.set_xticks([6, 13.3, 17.5])
    axis.set_xticklabels(["training set", "validation set", "test set"],
                         fontdict={'fontsize': 20})
    axis.tick_params(length=0, labeltop=True, labelbottom=False)
    axis.text(6, -.3, "Model fitting",
              fontdict={'fontsize': 13}, horizontalalignment="center")
    axis.text(13.3, -.3, "Parameter selection",
              fontdict={'fontsize': 13}, horizontalalignment="center")
    axis.text(17.5, -.3, "Evaluation",
              fontdict={'fontsize': 13}, horizontalalignment="center")
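
The helpers above only draw figures; they do not return the splits themselves. Below is a minimal usage sketch (not part of the original module) that renders each figure and, for comparison, prints the index arrays that the GroupKFold drawing encodes, using the same groups list. It assumes the definitions above have been executed in the current session and that scikit-learn and matplotlib are installed.

import numpy as np
from sklearn.model_selection import GroupKFold

# Render the figures defined above; each call creates its own figure.
plot_cross_validation()             # plain 5-fold CV on 25 points
plot_stratified_cross_validation()  # standard vs. stratified 3-fold CV
plot_shuffle_split()                # ShuffleSplit: 10 points, 4 splits
plot_group_kfold()                  # GroupKFold with 4 groups
plot_threefold_split()              # train / validation / test split
plt.show()

# The same GroupKFold splits as index arrays instead of a picture.
groups = [0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3]
for train, test in GroupKFold(n_splits=3).split(np.arange(12), groups=groups):
    print("train:", train, "test:", test)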