📚 The CoCalc Library - books, templates and other resources
License: OTHER
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.tree import export_graphviz
import matplotlib.pyplot as plt
from .plot_2d_separator import (plot_2d_separator, plot_2d_classification,
                                plot_2d_scores)
from .plot_helpers import cm2 as cm, discrete_scatter


def visualize_coefficients(coefficients, feature_names, n_top_features=25):
    """Visualize coefficients of a linear model.

    Parameters
    ----------
    coefficients : nd-array, shape (n_features,)
        Model coefficients. A row or column vector (e.g. shape
        (1, n_features) or (n_features, 1)) is accepted and flattened.

    feature_names : list or nd-array of strings, shape (n_features,)
        Feature names for labeling the coefficients.

    n_top_features : int, default=25
        How many features to show. The function will show the largest (most
        positive) and smallest (most negative) n_top_features coefficients,
        for a total of 2 * n_top_features coefficients.

    Raises
    ------
    ValueError
        If ``coefficients`` is not a vector, or its length differs from
        ``len(feature_names)``.
    """
    # np.asarray so that lists and np.matrix (which stays 2-D under
    # .squeeze()) are handled like plain arrays.
    coefficients = np.asarray(coefficients).squeeze()
    if coefficients.ndim > 1:
        # this is not a row or column vector
        raise ValueError("coefficients must be 1d array or column vector, got"
                         " shape {}".format(coefficients.shape))
    coefficients = coefficients.ravel()

    if len(coefficients) != len(feature_names):
        raise ValueError("Number of coefficients {} doesn't match number of"
                         " feature names {}.".format(len(coefficients),
                                                     len(feature_names)))
    # indices of the most negative and most positive coefficients
    order = np.argsort(coefficients)
    negative_coefficients = order[:n_top_features]
    positive_coefficients = order[-n_top_features:]
    interesting_coefficients = np.hstack([negative_coefficients,
                                          positive_coefficients])
    # normally 2 * n_top_features; fewer if there are not enough coefficients
    n_bars = len(interesting_coefficients)
    selected = coefficients[interesting_coefficients]

    # plot them: negative coefficients in cm(1), non-negative in cm(0)
    plt.figure(figsize=(15, 5))
    colors = [cm(1) if c < 0 else cm(0) for c in selected]
    positions = np.arange(n_bars)
    plt.bar(positions, selected, color=colors)
    feature_names = np.array(feature_names)
    plt.subplots_adjust(bottom=0.3)
    # ticks must use the same x positions as the bars, otherwise each label
    # ends up under the neighbouring bar
    plt.xticks(positions, feature_names[interesting_coefficients],
               rotation=60, ha="right")
    plt.ylabel("Coefficient magnitude")
    plt.xlabel("Feature")


def heatmap(values, xlabel, ylabel, xticklabels, yticklabels, cmap=None,
            vmin=None, vmax=None, ax=None, fmt="%0.2f"):
    """Draw ``values`` with ``pcolor`` and write each cell's value into it.

    Parameters
    ----------
    values : 2d array-like
        Data passed to ``ax.pcolor``.
    xlabel, ylabel : str
        Axis labels.
    xticklabels, yticklabels : sequence of str
        Labels placed at the centre of each column / row.
    cmap, vmin, vmax :
        Forwarded to ``ax.pcolor``.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.
    fmt : str, default="%0.2f"
        %-style format used for the in-cell text.

    Returns
    -------
    img
        The collection returned by ``ax.pcolor``.
    """
    if ax is None:
        ax = plt.gca()
    # plot the mean cross-validation scores
    img = ax.pcolor(values, cmap=cmap, vmin=vmin, vmax=vmax)
    # compute the face colours now so get_facecolors() reflects the colormap
    img.update_scalarmappable()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xticks(np.arange(len(xticklabels)) + .5)
    ax.set_yticks(np.arange(len(yticklabels)) + .5)
    ax.set_xticklabels(xticklabels)
    ax.set_yticklabels(yticklabels)
    ax.set_aspect(1)

    # Depending on the matplotlib version, get_array() is either flat or
    # 2-D/masked; compressing yields one value per drawn (unmasked) cell in
    # the same order as get_paths().
    cell_values = np.ma.compressed(img.get_array())
    for path, color, value in zip(img.get_paths(), img.get_facecolors(),
                                  cell_values):
        verts = path.vertices
        # centre of the cell's bounding box; unlike averaging a fixed slice
        # of vertices, this does not depend on how many closing vertices the
        # polygon path carries
        x, y = (verts.min(axis=0) + verts.max(axis=0)) / 2
        # dark text on light cells, light text on dark cells
        c = 'k' if np.mean(color[:3]) > 0.5 else 'w'
        ax.text(x, y, fmt % value, color=c, ha="center", va="center")
    return img


def make_handcrafted_dataset():
    """Return a small two-class 2D toy dataset ``(X, y)``.

    Starts from 30 ``make_blobs`` samples (2 centers, random_state=4),
    relabels samples 7 and 27 to class 0 and drops samples 0, 1, 5 and 26.
    """
    # a carefully hand-designed dataset lol
    X, y = make_blobs(centers=2, random_state=4, n_samples=30)
    y[np.array([7, 27])] = 0
    # builtin bool: the np.bool alias was removed in NumPy 1.24
    mask = np.ones(len(X), dtype=bool)
    mask[np.array([0, 1, 5, 26])] = 0
    X, y = X[mask], y[mask]
    return X, y


def print_topics(topics, feature_names, sorting, topics_per_chunk=6,
                 n_words=20):
    """Print the top words of each topic as side-by-side text columns.

    Parameters
    ----------
    topics : sequence of int
        Topic indices to print (used as row indices into ``sorting``).
    feature_names : sequence of str
        Vocabulary; indexed by the entries of ``sorting``.
    sorting : 2d array of int
        Row ``t`` lists feature indices for topic ``t`` in display order.
    topics_per_chunk : int, default=6
        Number of topic columns printed side by side.
    n_words : int, default=20
        Maximum number of words printed per topic; rows beyond the number of
        columns in ``sorting`` are skipped.
    """
    # arrays so that list inputs support fancy indexing below
    feature_names = np.asarray(feature_names)
    sorting = np.asarray(sorting)
    for i in range(0, len(topics), topics_per_chunk):
        # for each chunk:
        these_topics = topics[i: i + topics_per_chunk]
        # maybe we have less than topics_per_chunk left
        len_this_chunk = len(these_topics)
        # print topic headers
        print(("topic {:<8}" * len_this_chunk).format(*these_topics))
        print(("-------- {0:<5}" * len_this_chunk).format(""))
        # print top n_words frequent words
        for word_rank in range(n_words):
            try:
                words = feature_names[sorting[these_topics, word_rank]]
            except IndexError:
                # fewer than n_words ranked words available: skip this row
                continue
            print(("{:<14}" * len_this_chunk).format(*words))
        print("\n")


def get_tree(tree, **kwargs):
    """Export a fitted decision tree as a ``graphviz.Source`` object.

    ``kwargs`` are forwarded to ``sklearn.tree.export_graphviz``. The
    ``graphviz`` package is imported lazily, only when this is called.
    """
    try:
        # python3
        from io import StringIO
    except ImportError:
        # python2
        from StringIO import StringIO
    f = StringIO()
    export_graphviz(tree, out_file=f, **kwargs)
    import graphviz
    return graphviz.Source(f.getvalue())


__all__ = ['plot_2d_separator', 'plot_2d_classification', 'plot_2d_scores',
           'cm', 'visualize_coefficients', 'print_topics', 'heatmap',
           'discrete_scatter']