📚 The CoCalc Library - books, templates and other resources
License: OTHER
import warnings

import matplotlib.pyplot as plt
import numpy as np


def _import_interact():
    """Return the ``interact`` function from whichever widgets package exists.

    ``IPython.html.widgets`` was moved to the separate ``ipywidgets`` package
    and later dropped from IPython, so prefer the new location and fall back
    to the legacy import path for old installations.
    """
    try:
        from ipywidgets import interact
    except ImportError:
        from IPython.html.widgets import interact
    return interact


def plot_venn_diagram():
    """Draw a two-set Venn diagram: circles ``x`` and ``y`` inside a box ``I``."""
    _, ax = plt.subplots(subplot_kw=dict(frameon=False, xticks=[], yticks=[]))
    ax.add_patch(plt.Circle((0.3, 0.3), 0.3, fc='red', alpha=0.5))
    ax.add_patch(plt.Circle((0.6, 0.3), 0.3, fc='blue', alpha=0.5))
    ax.add_patch(plt.Rectangle((-0.1, -0.1), 1.1, 0.8, fc='none', ec='black'))
    ax.text(0.2, 0.3, '$x$', size=30, ha='center', va='center')
    ax.text(0.7, 0.3, '$y$', size=30, ha='center', va='center')
    ax.text(0.0, 0.6, '$I$', size=30)
    ax.axis('equal')


def plot_example_decision_tree():
    """Draw a fixed, hand-laid-out example decision tree (animal classification).

    All node positions are hard-coded in axes coordinates on a [0, 1] x [0, 1]
    canvas; solid lines join question nodes and dashed lines hint at further
    (undrawn) levels below the bottom row.
    """
    fig = plt.figure(figsize=(10, 4))
    ax = fig.add_axes([0, 0, 0.8, 1], frameon=False, xticks=[], yticks=[])
    ax.set_title('Example Decision Tree: Animal Classification', size=24)

    def text(ax, x, y, t, size=20, **kwargs):
        # Centered label inside a rounded white box.
        ax.text(x, y, t,
                ha='center', va='center', size=size,
                bbox=dict(boxstyle='round', ec='k', fc='w'), **kwargs)

    # Question nodes, top to bottom.
    text(ax, 0.5, 0.9, "How big is\nthe animal?", 20)
    text(ax, 0.3, 0.6, "Does the animal\nhave horns?", 18)
    text(ax, 0.7, 0.6, "Does the animal\nhave two legs?", 18)
    text(ax, 0.12, 0.3, "Are the horns\nlonger than 10cm?", 14)
    text(ax, 0.38, 0.3, "Is the animal\nwearing a collar?", 14)
    text(ax, 0.62, 0.3, "Does the animal\nhave wings?", 14)
    text(ax, 0.88, 0.3, "Does the animal\nhave a tail?", 14)

    # Faded edge labels.
    text(ax, 0.4, 0.75, "> 1m", 12, alpha=0.4)
    text(ax, 0.6, 0.75, "< 1m", 12, alpha=0.4)

    text(ax, 0.21, 0.45, "yes", 12, alpha=0.4)
    text(ax, 0.34, 0.45, "no", 12, alpha=0.4)

    text(ax, 0.66, 0.45, "yes", 12, alpha=0.4)
    text(ax, 0.79, 0.45, "no", 12, alpha=0.4)

    # Edges: solid between drawn nodes, dashed toward the undrawn next level.
    ax.plot([0.3, 0.5, 0.7], [0.6, 0.9, 0.6], '-k')
    ax.plot([0.12, 0.3, 0.38], [0.3, 0.6, 0.3], '-k')
    ax.plot([0.62, 0.7, 0.88], [0.3, 0.6, 0.3], '-k')
    ax.plot([0.0, 0.12, 0.20], [0.0, 0.3, 0.0], '--k')
    ax.plot([0.28, 0.38, 0.48], [0.0, 0.3, 0.0], '--k')
    ax.plot([0.52, 0.62, 0.72], [0.0, 0.3, 0.0], '--k')
    ax.plot([0.8, 0.88, 1.0], [0.0, 0.3, 0.0], '--k')
    ax.axis([0, 1, 0, 1])


def visualize_tree(estimator, X, y, boundaries=True,
                   xlim=None, ylim=None):
    """Fit ``estimator`` on 2-D data and plot its predicted regions.

    Parameters
    ----------
    estimator : classifier with ``fit``/``predict``
        Fitted in place on ``(X, y)``. When ``boundaries`` is true it must
        also expose a scikit-learn style ``tree_`` attribute (``feature``,
        ``threshold``, ``children_left``, ``children_right``).
    X : array, indexed as ``X[:, 0]`` and ``X[:, 1]``
        Training points; only the first two columns are plotted, and the
        estimator is asked to predict on 2-column grid points.
    y : array
        Labels used for fitting and for point colours.
    boundaries : bool
        Draw the tree's axis-aligned split lines on top of the regions.
    xlim, ylim : (min, max) or None
        Plot extent; defaults to the data range padded by 0.1.
    """
    estimator.fit(X, y)

    if xlim is None:
        xlim = (X[:, 0].min() - 0.1, X[:, 0].max() + 0.1)
    if ylim is None:
        ylim = (X[:, 1].min() - 0.1, X[:, 1].max() + 0.1)

    x_min, x_max = xlim
    y_min, y_max = ylim
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                         np.linspace(y_min, y_max, 100))
    Z = estimator.predict(np.c_[xx.ravel(), yy.ravel()])

    # Put the result into a color plot.  ``plt.clim`` acts on the most
    # recently added mappable, so it is applied once to the mesh here and
    # once to the scatter below to give both the same colour scale.
    Z = Z.reshape(xx.shape)
    plt.figure()
    plt.pcolormesh(xx, yy, Z, alpha=0.2, cmap='rainbow')
    plt.clim(y.min(), y.max())

    # Plot also the training points
    plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='rainbow')
    plt.axis('off')

    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.clim(y.min(), y.max())

    # Plot the decision boundaries
    def plot_boundaries(i, xlim, ylim):
        # Recursively draw the split at node ``i`` clipped to the current
        # sub-rectangle; negative child indices mark "no child".
        if i < 0:
            return

        tree = estimator.tree_

        if tree.feature[i] == 0:
            plt.plot([tree.threshold[i], tree.threshold[i]], ylim, '-k')
            plot_boundaries(tree.children_left[i],
                            [xlim[0], tree.threshold[i]], ylim)
            plot_boundaries(tree.children_right[i],
                            [tree.threshold[i], xlim[1]], ylim)

        elif tree.feature[i] == 1:
            plt.plot(xlim, [tree.threshold[i], tree.threshold[i]], '-k')
            plot_boundaries(tree.children_left[i], xlim,
                            [ylim[0], tree.threshold[i]])
            plot_boundaries(tree.children_right[i], xlim,
                            [tree.threshold[i], ylim[1]])

    if boundaries:
        plot_boundaries(0, plt.xlim(), plt.ylim())


def plot_tree_interactive(X, y):
    """Show ``visualize_tree`` for a decision tree with a max-depth slider (1-5).

    Returns whatever the widgets ``interact`` call returns.
    """
    from sklearn.tree import DecisionTreeClassifier

    def interactive_tree(depth=1):
        clf = DecisionTreeClassifier(max_depth=depth, random_state=0)
        visualize_tree(clf, X, y)

    interact = _import_interact()
    # A (min, max) tuple yields an integer slider; a list would be turned
    # into a dropdown offering only those two values by ipywidgets.
    return interact(interactive_tree, depth=(1, 5))


def plot_kmeans_interactive(min_clusters=1, max_clusters=6):
    """Step through k-means on synthetic blobs with interactive sliders.

    Each group of three ``frame`` values shows: the current centers, then the
    points reassigned to their nearest center, then the centers moved to the
    cluster means (with arrows).  ``n_clusters`` ranges over
    ``[min_clusters, max_clusters]``.  Returns the ``interact`` result.
    """
    interact = _import_interact()
    from sklearn.metrics.pairwise import euclidean_distances
    # ``sklearn.datasets.samples_generator`` was removed in scikit-learn 0.24;
    # ``make_blobs`` is importable from ``sklearn.datasets`` directly.
    from sklearn.datasets import make_blobs

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore')

        X, _ = make_blobs(n_samples=300, centers=4,
                          random_state=0, cluster_std=0.60)

        def _kmeans_step(frame=0, n_clusters=4):
            # Fixed seed so the initial centers are the same for every frame.
            rng = np.random.RandomState(2)
            labels = np.zeros(X.shape[0])
            centers = rng.randn(n_clusters, 2)

            nsteps = frame // 3

            for step in range(nsteps + 1):
                old_centers = centers
                if step < nsteps or frame % 3 > 0:
                    dist = euclidean_distances(X, centers)
                    labels = dist.argmin(1)

                if step < nsteps or frame % 3 > 1:
                    # An empty cluster has a NaN mean (and numpy warns about
                    # it); keep its previous center instead.  The warning is
                    # silenced here because widget callbacks run after the
                    # enclosing ``with`` block has already exited.
                    with warnings.catch_warnings():
                        warnings.simplefilter('ignore', RuntimeWarning)
                        centers = np.array([X[labels == j].mean(0)
                                            for j in range(n_clusters)])
                    nans = np.isnan(centers)
                    centers[nans] = old_centers[nans]

            # plot the data and cluster centers
            plt.scatter(X[:, 0], X[:, 1], c=labels, s=50, cmap='rainbow',
                        vmin=0, vmax=n_clusters - 1)
            plt.scatter(old_centers[:, 0], old_centers[:, 1], marker='o',
                        c=np.arange(n_clusters),
                        s=200, cmap='rainbow')
            plt.scatter(old_centers[:, 0], old_centers[:, 1], marker='o',
                        c='black', s=50)

            # plot new centers if third frame
            if frame % 3 == 2:
                for k in range(n_clusters):
                    plt.annotate('', centers[k], old_centers[k],
                                 arrowprops=dict(arrowstyle='->', linewidth=1))
                plt.scatter(centers[:, 0], centers[:, 1], marker='o',
                            c=np.arange(n_clusters),
                            s=200, cmap='rainbow')
                plt.scatter(centers[:, 0], centers[:, 1], marker='o',
                            c='black', s=50)

            plt.xlim(-4, 4)
            plt.ylim(-2, 10)

            if frame % 3 == 1:
                plt.text(3.8, 9.5, "1. Reassign points to nearest centroid",
                         ha='right', va='top', size=14)
            elif frame % 3 == 2:
                plt.text(3.8, 9.5, "2. Update centroids to cluster means",
                         ha='right', va='top', size=14)

        # Tuples give integer sliders; lists would become two-item dropdowns.
        return interact(_kmeans_step, frame=(0, 50),
                        n_clusters=(min_clusters, max_clusters))


def plot_image_components(x, coefficients=None, mean=0, components=None,
                          imshape=(8, 8), n_components=6, fontsize=12):
    """Plot an image next to its progressive reconstruction from components.

    Layout (2 rows): the true image on the left, then the mean, then for each
    of the first ``n_components`` components the component itself (top) and
    the running approximation ``mean + sum(coef_k * c_k)`` (bottom), and the
    final approximation on the right.

    Parameters
    ----------
    x : array
        Flattened image; must be reshapeable to ``imshape``.
    coefficients : sequence, optional
        Weight for each component; defaults to ``x`` itself.
    mean : scalar or array
        Offset added before the components (broadcast against ``x``).
    components : 2-D array, optional
        Rows are components; defaults to ``np.eye(len(coefficients), len(x))``
        (i.e. single-pixel basis images).
    imshape : tuple
        Shape used to display each flattened vector.
    n_components : int
        Number of components shown; must not exceed the available rows.
    fontsize : int
        Font size for titles and the ``+`` markers.
    """
    if coefficients is None:
        coefficients = x

    if components is None:
        components = np.eye(len(coefficients), len(x))

    mean = np.zeros_like(x) + mean

    fig = plt.figure(figsize=(1.2 * (5 + n_components), 1.2 * 2))
    g = plt.GridSpec(2, 5 + n_components, hspace=0.3)

    def show(i, j, x, title=None):
        ax = fig.add_subplot(g[i, j], xticks=[], yticks=[])
        ax.imshow(x.reshape(imshape), interpolation='nearest')
        if title:
            ax.set_title(title, fontsize=fontsize)

    show(slice(2), slice(2), x, "True")

    approx = mean.copy()
    show(0, 2, mean, r'$\mu$')
    show(1, 2, approx, r'$1 \cdot \mu$')

    for i in range(n_components):
        approx = approx + coefficients[i] * components[i]
        # ``{{...}}`` emits literal braces so multi-digit indices (c_10, ...)
        # are fully subscripted in mathtext, not just their first digit.
        show(0, i + 3, components[i], r'$c_{{{0}}}$'.format(i + 1))
        show(1, i + 3, approx,
             r"${0:.2f} \cdot c_{{{1}}}$".format(coefficients[i], i + 1))
        plt.gca().text(0, 1.05, '$+$', ha='right', va='bottom',
                       transform=plt.gca().transAxes, fontsize=fontsize)

    show(slice(2), slice(-2, None), approx, "Approx")


def plot_pca_interactive(data, n_components=6):
    """Fit PCA on ``data`` and browse per-sample reconstructions with a slider.

    Returns the ``interact`` result, like the other interactive helpers.
    """
    from sklearn.decomposition import PCA
    interact = _import_interact()

    pca = PCA(n_components=n_components)
    Xproj = pca.fit_transform(data)

    def show_decomp(i=0):
        plot_image_components(data[i], Xproj[i],
                              pca.mean_, pca.components_)

    return interact(show_decomp, i=(0, data.shape[0] - 1))