📚 The CoCalc Library - books, templates and other resources
License: OTHER
import numpy as np

from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances
import matplotlib.pyplot as plt
import matplotlib as mpl
from cycler import cycler

from .tools import discrete_scatter
from .plot_2d_separator import plot_2d_classification
from .plot_helpers import cm3


def _plot_step(ax, title, X, labels, centers):
    """Draw one k-means panel: points grouped by ``labels`` plus the centers.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    title : title text for the panel.
    X : 2-D array; columns 0 and 1 are used as plot coordinates.
    labels : per-point cluster assignment passed to ``discrete_scatter``.
    centers : 2-D array of center coordinates; its rows are labelled
        0, 1 and 2.

    Returns
    -------
    The value returned by ``discrete_scatter`` for the center markers, so
    that a caller can reuse it as legend handles.
    """
    ax.set_title(title)
    discrete_scatter(X[:, 0], X[:, 1], labels, markers=['o'], ax=ax)
    return discrete_scatter(centers[:, 0], centers[:, 1], [0, 1, 2],
                            ax=ax, markers=['^'], markeredgewidth=2)


def plot_kmeans_algorithm():
    """Plot a 3x3 grid that walks through the first iterations of k-means.

    The data comes from ``make_blobs(random_state=1)`` and the first three
    samples serve as initial centers.  Panels alternate between assigning
    points to centers and recomputing the centers; the last panel only
    holds the legend.
    """
    X, _ = make_blobs(random_state=1)
    init = X[:3, :]

    # we don't want cyan in there
    palette = cycler('color', ['#0000aa', '#ff2020', '#50ff50'])
    with mpl.rc_context(rc={'axes.prop_cycle': palette}):
        _, axes = plt.subplots(3, 3, figsize=(10, 8),
                               subplot_kw={'xticks': (), 'yticks': ()})
        axes = axes.ravel()

        axes[0].set_title("Input data")
        discrete_scatter(X[:, 0], X[:, 1], ax=axes[0], markers=['o'], c='w')

        axes[1].set_title("Initialization")
        discrete_scatter(X[:, 0], X[:, 1], ax=axes[1], markers=['o'], c='w')
        discrete_scatter(init[:, 0], init[:, 1], [0, 1, 2], ax=axes[1],
                         markers=['^'], markeredgewidth=2)

        km = KMeans(n_clusters=3, init=init, max_iter=1, n_init=1).fit(X)

        # The very first assignment is computed by hand from the initial
        # centers.  Original author's note: scikit-learn does two e-steps
        # for max_iter=1, so km.labels_ is already the *re*assignment.
        labels = np.argmin(pairwise_distances(init, X), axis=0)
        _plot_step(axes[2], "Assign Points (1)", X, labels, init)

        centers = km.cluster_centers_
        _plot_step(axes[3], "Recompute Centers (1)", X, labels, centers)

        # Reuse the max_iter=1 fit from above; refitting with identical
        # arguments and a fixed init array would only repeat the same work.
        labels = km.labels_
        _plot_step(axes[4], "Reassign Points (2)", X, labels, centers)

        km = KMeans(n_clusters=3, init=init, max_iter=2, n_init=1).fit(X)
        centers = km.cluster_centers_
        _plot_step(axes[5], "Recompute Centers (2)", X, labels, centers)

        labels = km.labels_
        # Keep these center markers: they become the legend handles below.
        markers = _plot_step(axes[6], "Reassign Points (3)", X, labels,
                             centers)

        km = KMeans(n_clusters=3, init=init, max_iter=3, n_init=1).fit(X)
        centers = km.cluster_centers_
        _plot_step(axes[7], "Recompute Centers (3)", X, labels, centers)

        axes[8].set_axis_off()
        axes[8].legend(markers, ["Cluster 0", "Cluster 1", "Cluster 2"],
                       loc='best')


def plot_kmeans_boundaries():
    """Plot the blobs data, the k-means centers and the cluster regions.

    Fits a 3-cluster ``KMeans`` (``max_iter=2``, initialised with the first
    three samples) on ``make_blobs(random_state=1)`` and draws points,
    centers and the result of ``plot_2d_classification`` on the current
    axes.
    """
    X, _ = make_blobs(random_state=1)
    init = X[:3, :]
    km = KMeans(n_clusters=3, init=init, max_iter=2, n_init=1).fit(X)

    discrete_scatter(X[:, 0], X[:, 1], km.labels_, markers=['o'])
    discrete_scatter(km.cluster_centers_[:, 0], km.cluster_centers_[:, 1],
                     [0, 1, 2], markers=['^'], markeredgewidth=2)
    plot_2d_classification(km, X, cm=cm3, alpha=.4)


def plot_kmeans_faces(km, pca, X_pca, X_people, y_people, target_names):
    """Show each cluster center with its five closest and five farthest images.

    One row is drawn per cluster: column 0 is the center mapped back through
    ``pca.inverse_transform``, columns 1-5 are the members nearest to the
    center and columns 6-10 the members farthest from it.

    Parameters
    ----------
    km : fitted clustering object exposing ``cluster_centers_`` and
        ``labels_``.
    pca : object exposing ``inverse_transform``; its output for a center is
        reshaped to ``image_shape``.
    X_pca : array in the same space as ``km.cluster_centers_`` (squared
        Euclidean distances to each center are computed from it).
    X_people : array of images, row-aligned with ``X_pca``; each row is
        reshaped to ``image_shape``.
    y_people : per-image indices into ``target_names``.
    target_names : sequence of name strings; only the last
        whitespace-separated word is shown as the title.
    """
    # NOTE: the row count, the image shape and the rectangle sizes below are
    # hard-coded and belong together (10 rows of 87x65 images).
    n_clusters = 10
    image_shape = (87, 65)
    _, axes = plt.subplots(n_clusters, 11,
                           subplot_kw={'xticks': (), 'yticks': ()},
                           figsize=(10, 15), gridspec_kw={"hspace": .3})

    for cluster in range(n_clusters):
        center = km.cluster_centers_[cluster]
        mask = km.labels_ == cluster
        dists = np.sum((X_pca - center) ** 2, axis=1)

        # Push non-members to +inf so the five smallest are cluster members,
        # then to -inf so the five largest are cluster members as well.
        # NOTE: a cluster with fewer than five members will pull in
        # non-member images here.
        dists[~mask] = np.inf
        closest = np.argsort(dists)[:5]
        dists[~mask] = -np.inf
        farthest = np.argsort(dists)[-5:]
        inds = np.r_[closest, farthest]

        axes[cluster, 0].imshow(
            pca.inverse_transform(center).reshape(image_shape),
            vmin=0, vmax=1)
        for image, label, ax in zip(X_people[inds], y_people[inds],
                                    axes[cluster, 1:]):
            ax.imshow(image.reshape(image_shape), vmin=0, vmax=1)
            ax.set_title(str(target_names[label].split()[-1]),
                         fontdict={'fontsize': 9})

    # add some boxes to illustrate which are similar and which dissimilar
    boxes = [
        (0, 73, "Center"),
        (1, 385, "Close to center"),
        (6, 385, "Far from center"),
    ]
    for column, width, caption in boxes:
        ax = axes[0, column]
        rec = ax.add_patch(
            plt.Rectangle([-5, -30], width, 1295, fill=False, lw=2))
        # The box extends far beyond its anchor axes, so clipping must be off.
        rec.set_clip_on(False)
        ax.text(0, -40, caption)