Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
Download

📚 The CoCalc Library - books, templates and other resources

132923 views
License: OTHER
1
import numpy as np
2
from sklearn.datasets import make_blobs
3
from sklearn.tree import export_graphviz
4
import matplotlib.pyplot as plt
5
from .plot_2d_separator import (plot_2d_separator, plot_2d_classification,
6
plot_2d_scores)
7
from .plot_helpers import cm2 as cm, discrete_scatter
8
9
10
def visualize_coefficients(coefficients, feature_names, n_top_features=25):
11
"""Visualize coefficients of a linear model.
12
13
Parameters
14
----------
15
coefficients : nd-array, shape (n_features,)
16
Model coefficients.
17
18
feature_names : list or nd-array of strings, shape (n_features,)
19
Feature names for labeling the coefficients.
20
21
n_top_features : int, default=25
22
How many features to show. The function will show the largest (most
23
positive) and smallest (most negative) n_top_features coefficients,
24
for a total of 2 * n_top_features coefficients.
25
"""
26
coefficients = coefficients.squeeze()
27
if coefficients.ndim > 1:
28
# this is not a row or column vector
29
raise ValueError("coeffients must be 1d array or column vector, got"
30
" shape {}".format(coefficients.shape))
31
coefficients = coefficients.ravel()
32
33
if len(coefficients) != len(feature_names):
34
raise ValueError("Number of coefficients {} doesn't match number of"
35
"feature names {}.".format(len(coefficients),
36
len(feature_names)))
37
# get coefficients with large absolute values
38
coef = coefficients.ravel()
39
positive_coefficients = np.argsort(coef)[-n_top_features:]
40
negative_coefficients = np.argsort(coef)[:n_top_features]
41
interesting_coefficients = np.hstack([negative_coefficients,
42
positive_coefficients])
43
# plot them
44
plt.figure(figsize=(15, 5))
45
colors = [cm(1) if c < 0 else cm(0)
46
for c in coef[interesting_coefficients]]
47
plt.bar(np.arange(2 * n_top_features), coef[interesting_coefficients],
48
color=colors)
49
feature_names = np.array(feature_names)
50
plt.subplots_adjust(bottom=0.3)
51
plt.xticks(np.arange(1, 1 + 2 * n_top_features),
52
feature_names[interesting_coefficients], rotation=60,
53
ha="right")
54
plt.ylabel("Coefficient magnitude")
55
plt.xlabel("Feature")
56
57
58
def heatmap(values, xlabel, ylabel, xticklabels, yticklabels, cmap=None,
            vmin=None, vmax=None, ax=None, fmt="%0.2f"):
    """Plot a 2D array as a heatmap with each cell annotated by its value.

    Parameters
    ----------
    values : 2d array-like
        Data to plot; drawn with ``ax.pcolor``.
    xlabel, ylabel : str
        Axis labels.
    xticklabels, yticklabels : sequence
        Labels placed at the centers of the columns / rows (ticks at
        ``i + 0.5``).
    cmap, vmin, vmax : optional
        Passed through to ``ax.pcolor``.
    ax : matplotlib Axes, optional
        Axes to draw into; defaults to ``plt.gca()``.
    fmt : str, default="%0.2f"
        %-style format string used for the cell annotations.

    Returns
    -------
    img
        The collection returned by ``ax.pcolor``.
    """
    if ax is None:
        ax = plt.gca()
    # draw one colored cell per entry of values
    img = ax.pcolor(values, cmap=cmap, vmin=vmin, vmax=vmax)
    # compute face colors from the colormap now so they can be read below
    img.update_scalarmappable()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xticks(np.arange(len(xticklabels)) + .5)
    ax.set_yticks(np.arange(len(yticklabels)) + .5)
    ax.set_xticklabels(xticklabels)
    ax.set_yticklabels(yticklabels)
    ax.set_aspect(1)

    # get_array() is 1d in older matplotlib but may be a 2d (masked) array in
    # newer releases; flatten to one value per drawn cell either way.
    cell_values = np.ma.asarray(img.get_array()).compressed()
    for path, color, value in zip(img.get_paths(), img.get_facecolors(),
                                  cell_values):
        # Cell center = midpoint of the path's bounding box. Unlike averaging
        # a fixed slice of vertices, this does not depend on how many closing
        # vertices matplotlib appends to each polygon.
        verts = path.vertices
        x, y = (verts.min(axis=0) + verts.max(axis=0)) / 2
        # dark text on light cells, light text on dark cells
        text_color = 'k' if np.mean(color[:3]) > 0.5 else 'w'
        ax.text(x, y, fmt % value, color=text_color, ha="center", va="center")
    return img
82
83
84
def make_handcrafted_dataset():
    """Return a small hand-tuned two-class dataset.

    Draws 30 samples from ``make_blobs`` (two centers, ``random_state=4``),
    relabels samples 7 and 27 as class 0, and drops samples 0, 1, 5 and 26,
    leaving 26 samples.

    Returns
    -------
    X : ndarray
        Feature matrix with 26 rows.
    y : ndarray
        The 26 corresponding class labels.
    """
    # a carefully hand-designed dataset lol
    X, y = make_blobs(centers=2, random_state=4, n_samples=30)
    y[np.array([7, 27])] = 0
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` selects the
    # same boolean dtype on every NumPy version.
    mask = np.ones(len(X), dtype=bool)
    mask[np.array([0, 1, 5, 26])] = 0
    X, y = X[mask], y[mask]
    return X, y
92
93
94
def print_topics(topics, feature_names, sorting, topics_per_chunk=6,
                 n_words=20):
    """Print the top words of several topics side by side.

    Topics are printed in chunks of ``topics_per_chunk`` columns: a header
    row with the topic indices, a separator row, then one row per word.

    Parameters
    ----------
    topics : sequence of int
        Topic indices to print; each is used as a row index into ``sorting``.
    feature_names : list or nd-array of strings
        Vocabulary; ``feature_names[j]`` is printed for feature index ``j``.
    sorting : 2d array of int
        ``sorting[t, k]`` is the feature index printed as the k-th word of
        topic ``t`` (presumably an argsort of topic-word weights in
        descending order -- confirm against the caller).
    topics_per_chunk : int, default=6
        How many topics are printed next to each other.
    n_words : int, default=20
        How many words to print per topic; capped at ``sorting.shape[1]``.
    """
    # Fancy indexing below needs arrays; a plain list of names would
    # otherwise fail on every row.
    feature_names = np.asarray(feature_names)
    sorting = np.asarray(sorting)
    # Never ask for more words than each topic row provides.
    n_words = min(n_words, sorting.shape[1])
    for start in range(0, len(topics), topics_per_chunk):
        # for each chunk:
        these_topics = topics[start: start + topics_per_chunk]
        # maybe we have less than topics_per_chunk left
        len_this_chunk = len(these_topics)
        # print topic headers
        print(("topic {:<8}" * len_this_chunk).format(*these_topics))
        print(("-------- {0:<5}" * len_this_chunk).format(""))
        # print top n_words frequent words
        for word_idx in range(n_words):
            words = feature_names[sorting[these_topics, word_idx]]
            print(("{:<14}" * len_this_chunk).format(*words))
        print("\n")
112
113
114
def get_tree(tree, **kwargs):
    """Render a fitted tree as a ``graphviz.Source`` object.

    The tree is written in DOT format by ``export_graphviz`` into an
    in-memory text buffer; any keyword arguments are forwarded to
    ``export_graphviz``. The optional ``graphviz`` package is only imported
    when this function is called.
    """
    try:
        from io import StringIO  # Python 3
    except ImportError:
        from StringIO import StringIO  # Python 2 fallback
    dot_buffer = StringIO()
    export_graphviz(tree, dot_buffer, **kwargs)
    import graphviz
    dot_source = dot_buffer.getvalue()
    return graphviz.Source(dot_source)
125
126
# Public names of this module (also what ``from ... import *`` exports);
# includes every public function defined here.
__all__ = ['plot_2d_separator', 'plot_2d_classification', 'plot_2d_scores',
           'cm', 'visualize_coefficients', 'print_topics', 'heatmap',
           'discrete_scatter', 'make_handcrafted_dataset', 'get_tree']
129
130