📚 The CoCalc Library - books, templates and other resources
License: OTHER
import numpy as np
import matplotlib.pyplot as plt
from .plot_helpers import cm2, cm3, discrete_scatter


def _call_classifier_chunked(classifier_pred_or_decide, X):
    """Apply a classifier method to ``X`` chunk by chunk and join the results.

    Parameters
    ----------
    classifier_pred_or_decide : callable
        Bound classifier method (the callers in this module pass
        ``predict``, ``decision_function`` or ``predict_proba``). It is
        called once per chunk of rows of ``X``.
    X : numpy.ndarray
        Input samples; split along axis 0.

    Returns
    -------
    numpy.ndarray
        The per-chunk results concatenated along axis 0.
    """
    # The chunk_size is used to chunk the large arrays to work with x86
    # memory models that are restricted to < 2 GB in memory allocation. The
    # chunk_size value used here is based on a measurement with the
    # MLPClassifier using the following parameters:
    # MLPClassifier(solver='lbfgs', random_state=0,
    #               hidden_layer_sizes=[1000,1000,1000])
    # by reducing the value it is possible to trade in time for memory.
    # It is possible to chunk the array as the calculations are independent of
    # each other.
    # Note: an intermittent version made a distinction between
    # 32- and 64 bit architectures avoiding the chunking. Testing revealed
    # that even on 64 bit architectures the chunking increases the
    # performance by a factor of 3-5, largely due to the avoidance of memory
    # swapping.
    chunk_size = 10000

    # Row indices at which to cut X: chunk_size, 2 * chunk_size, ...
    # An empty index array (X shorter than chunk_size) yields a single chunk.
    split_points = np.arange(chunk_size, X.shape[0], chunk_size,
                             dtype=np.int32)

    # Call the classifier in chunks and collect all result chunks.
    Y_result_chunks = [
        classifier_pred_or_decide(x_chunk)
        for x_chunk in np.array_split(X, split_points, axis=0)
    ]

    return np.concatenate(Y_result_chunks)


def _make_grid(X, eps, n_points):
    """Build a regular evaluation grid around the first two columns of ``X``.

    Parameters
    ----------
    X : numpy.ndarray
        Data whose columns 0 and 1 define the plotted region.
    eps : float or None
        Margin added on every side; ``None`` means ``X.std() / 2``.
    n_points : int
        Number of grid points along each axis.

    Returns
    -------
    extent : tuple
        ``(x_min, x_max, y_min, y_max)`` of the padded region.
    X1, X2 : numpy.ndarray
        Mesh-grid coordinate matrices, each of shape
        ``(n_points, n_points)``.
    X_grid : numpy.ndarray
        The grid flattened to shape ``(n_points ** 2, 2)``.
    """
    if eps is None:
        eps = X.std() / 2.

    x_min, x_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    y_min, y_max = X[:, 1].min() - eps, X[:, 1].max() + eps
    xx = np.linspace(x_min, x_max, n_points)
    yy = np.linspace(y_min, y_max, n_points)

    X1, X2 = np.meshgrid(xx, yy)
    X_grid = np.c_[X1.ravel(), X2.ravel()]
    return (x_min, x_max, y_min, y_max), X1, X2, X_grid


def _finish_axes(ax, extent):
    """Clamp ``ax`` to ``extent`` and remove the tick marks on both axes."""
    x_min, x_max, y_min, y_max = extent
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks(())
    ax.set_yticks(())


def plot_2d_classification(classifier, X, fill=False, ax=None, eps=None,
                           alpha=1, cm=cm3):
    """Draw the class predicted by ``classifier`` over a 2D grid as an image.

    Parameters
    ----------
    classifier : object
        Fitted estimator exposing ``predict``.
    X : numpy.ndarray
        Data whose first two columns determine the plotted region.
    fill : bool
        Not used by this function; kept in the signature for callers that
        pass it.
    ax : matplotlib axes or None
        Axes to draw on; ``None`` means ``plt.gca()``.
    eps : float or None
        Margin around the data; ``None`` means ``X.std() / 2``.
    alpha : float
        Opacity passed to ``imshow``.
    cm : colormap
        Colormap passed to ``imshow``.
    """
    # multiclass
    extent, X1, _, X_grid = _make_grid(X, eps, 1000)

    if ax is None:
        ax = plt.gca()

    # The grid holds 1000 * 1000 points, the same size plot_2d_separator
    # evaluates, so go through the same chunked helper to bound peak memory.
    decision_values = _call_classifier_chunked(classifier.predict, X_grid)
    ax.imshow(decision_values.reshape(X1.shape), extent=extent,
              aspect='auto', origin='lower', alpha=alpha, cmap=cm)
    _finish_axes(ax, extent)


def plot_2d_scores(classifier, X, ax=None, eps=None, alpha=1, cm="viridis",
                   function=None):
    """Draw a classifier's continuous score over a 2D grid as an image.

    Parameters
    ----------
    classifier : object
        Fitted estimator.
    X : numpy.ndarray
        Data whose first two columns determine the plotted region.
    ax : matplotlib axes or None
        Axes to draw on; ``None`` means ``plt.gca()``.
    eps : float or None
        Margin around the data; ``None`` means ``X.std() / 2``.
    alpha : float
        Opacity passed to ``imshow``.
    cm : colormap or str
        Colormap passed to ``imshow``.
    function : str or None
        Name of the classifier method to evaluate. ``None`` selects
        ``decision_function`` when the classifier has one and
        ``predict_proba`` otherwise.

    Returns
    -------
    The object returned by ``ax.imshow``.

    Raises
    ------
    AttributeError
        If the requested (or fallback) method does not exist on
        ``classifier``.
    """
    # binary with fill
    extent, X1, _, X_grid = _make_grid(X, eps, 100)

    if ax is None:
        ax = plt.gca()

    if function is None:
        # Look up predict_proba only when decision_function is missing.
        # Passing getattr(classifier, "predict_proba") as the getattr default
        # would evaluate it eagerly and raise for classifiers that offer
        # decision_function but no predict_proba.
        function = getattr(classifier, "decision_function", None)
        if function is None:
            function = classifier.predict_proba
    else:
        function = getattr(classifier, function)

    decision_values = function(X_grid)
    if decision_values.ndim > 1 and decision_values.shape[1] > 1:
        # predict_proba: keep only the column at index 1
        decision_values = decision_values[:, 1]

    grr = ax.imshow(decision_values.reshape(X1.shape), extent=extent,
                    aspect='auto', origin='lower', alpha=alpha, cmap=cm)

    _finish_axes(ax, extent)
    return grr


def plot_2d_separator(classifier, X, fill=False, ax=None, eps=None, alpha=1,
                      cm=cm2, linewidth=None, threshold=None,
                      linestyle="solid"):
    """Draw the decision boundary of ``classifier`` as a contour.

    Parameters
    ----------
    classifier : object
        Fitted estimator exposing ``decision_function`` or, failing that,
        ``predict_proba``.
    X : numpy.ndarray
        Data whose first two columns determine the plotted region.
    fill : bool
        If true, draw filled regions with ``contourf``; otherwise draw only
        the boundary line with ``contour``.
    ax : matplotlib axes or None
        Axes to draw on; ``None`` means ``plt.gca()``.
    eps : float or None
        Margin around the data; ``None`` means ``X.std() / 2``.
    alpha : float
        Opacity of the contour.
    cm : colormap
        Colormap used for the filled variant.
    linewidth, linestyle
        Passed to ``contour`` for the unfilled variant.
    threshold : float or None
        Contour level of the boundary. ``None`` means 0 for
        ``decision_function`` values and 0.5 for ``predict_proba`` values.
    """
    # binary?
    extent, X1, X2, X_grid = _make_grid(X, eps, 1000)

    if ax is None:
        ax = plt.gca()

    if hasattr(classifier, "decision_function"):
        decision_values = _call_classifier_chunked(
            classifier.decision_function, X_grid)
        levels = [0] if threshold is None else [threshold]
        # Filled regions run from the lowest score, through the boundary
        # level, up to the highest score.
        fill_levels = [decision_values.min()] + levels + [
            decision_values.max()]
    else:
        # no decision_function
        decision_values = _call_classifier_chunked(
            classifier.predict_proba, X_grid)[:, 1]
        levels = [.5] if threshold is None else [threshold]
        fill_levels = [0] + levels + [1]

    if fill:
        ax.contourf(X1, X2, decision_values.reshape(X1.shape),
                    levels=fill_levels, alpha=alpha, cmap=cm)
    else:
        ax.contour(X1, X2, decision_values.reshape(X1.shape), levels=levels,
                   colors="black", alpha=alpha, linewidths=linewidth,
                   linestyles=linestyle, zorder=5)

    _finish_axes(ax, extent)


if __name__ == '__main__':
    # Small visual demo: fit a logistic regression on two blobs and draw its
    # filled decision regions together with the training points.
    from sklearn.datasets import make_blobs
    from sklearn.linear_model import LogisticRegression
    X, y = make_blobs(centers=2, random_state=42)
    clf = LogisticRegression(solver='lbfgs').fit(X, y)
    plot_2d_separator(clf, X, fill=True)
    discrete_scatter(X[:, 0], X[:, 1], y)
    plt.show()