import numpy as np
import matplotlib.pyplot as plt

# Plotting helpers from the surrounding package (relative imports).
from .tools import plot_2d_separator, plot_2d_scores, cm, discrete_scatter
from .plot_helpers import ReBl


def plot_confusion_matrix_illustration():
    """Draw the 2x2 confusion matrix of the "nine vs. not nine" example
    with hard-coded counts."""
    plt.figure(figsize=(8, 8))
    confusion = np.array([[401, 2], [8, 39]])
    plt.text(0.40, .7, confusion[0, 0], size=70, horizontalalignment='right')
    plt.text(0.40, .2, confusion[1, 0], size=70, horizontalalignment='right')
    plt.text(.90, .7, confusion[0, 1], size=70, horizontalalignment='right')
    plt.text(.90, 0.2, confusion[1, 1], size=70, horizontalalignment='right')
    plt.xticks([.25, .75], ["predicted 'not nine'", "predicted 'nine'"], size=20)
    plt.yticks([.25, .75], ["true 'nine'", "true 'not nine'"], size=20)
    # Dashed lines separating the four cells.
    plt.plot([.5, .5], [0, 1], '--', c='k')
    plt.plot([0, 1], [.5, .5], '--', c='k')

    plt.xlim(0, 1)
    plt.ylim(0, 1)


def plot_binary_confusion_matrix():
    """Draw the generic TN / FP / FN / TP layout of a binary confusion matrix."""
    plt.text(0.45, .6, "TN", size=100, horizontalalignment='right')
    plt.text(0.45, .1, "FN", size=100, horizontalalignment='right')
    plt.text(.95, .6, "FP", size=100, horizontalalignment='right')
    plt.text(.95, 0.1, "TP", size=100, horizontalalignment='right')
    plt.xticks([.25, .75], ["predicted negative", "predicted positive"], size=15)
    plt.yticks([.25, .75], ["positive class", "negative class"], size=15)
    # Dashed lines separating the four cells.
    plt.plot([.5, .5], [0, 1], '--', c='k')
    plt.plot([0, 1], [.5, .5], '--', c='k')

    plt.xlim(0, 1)
    plt.ylim(0, 1)


def plot_decision_threshold():
    """Fit an SVC on an imbalanced two-blob dataset and compare the decision
    boundary at the default threshold 0 with a lowered threshold of -0.8."""
    from sklearn.datasets import make_blobs
    from sklearn.svm import SVC
    from sklearn.model_selection import train_test_split

    # Imbalanced data: 400 samples in the negative blob, 50 in the positive one.
    X, y = make_blobs(n_samples=(400, 50), cluster_std=[7.0, 2],
                      random_state=22)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    fig, axes = plt.subplots(2, 3, figsize=(15, 8), subplot_kw={'xticks': (), 'yticks': ()})
    plt.suptitle("decision_threshold")
    axes[0, 0].set_title("training data")
    discrete_scatter(X_train[:, 0], X_train[:, 1], y_train, ax=axes[0, 0])

    svc = SVC(gamma=.05).fit(X_train, y_train)
    axes[0, 1].set_title("decision with threshold 0")
    discrete_scatter(X_train[:, 0], X_train[:, 1], y_train, ax=axes[0, 1])
    plot_2d_scores(svc, X_train, function="decision_function", alpha=.7,
                   ax=axes[0, 1], cm=ReBl)
    plot_2d_separator(svc, X_train, linewidth=3, ax=axes[0, 1])
    axes[0, 2].set_title("decision with threshold -0.8")
    discrete_scatter(X_train[:, 0], X_train[:, 1], y_train, ax=axes[0, 2])
    plot_2d_separator(svc, X_train, linewidth=3, ax=axes[0, 2], threshold=-.8)
    plot_2d_scores(svc, X_train, function="decision_function", alpha=.7,
                   ax=axes[0, 2], cm=ReBl)

    axes[1, 0].set_axis_off()

    # Points close to the horizontal cross-section used in the bottom row.
    mask = np.abs(X_train[:, 1] - 7) < 5
    bla = np.sum(mask)  # number of selected points

    line = np.linspace(X_train.min(), X_train.max(), 100)
    axes[1, 1].set_title("Cross-section with threshold 0")
    axes[1, 1].plot(line, svc.decision_function(np.c_[line, 10 * np.ones(100)]), c='k')
    dec = svc.decision_function(np.c_[line, 10 * np.ones(100)])
    contour = (dec > 0).reshape(1, -1).repeat(10, axis=0)
    axes[1, 1].contourf(line, np.linspace(-1.5, 1.5, 10), contour, alpha=0.4, cmap=cm)
    discrete_scatter(X_train[mask, 0], np.zeros(bla), y_train[mask], ax=axes[1, 1])
    axes[1, 1].set_xlim(X_train.min(), X_train.max())
    axes[1, 1].set_ylim(-1.5, 1.5)
    axes[1, 1].set_xticks(())
    axes[1, 1].set_ylabel("Decision value")

    contour2 = (dec > -.8).reshape(1, -1).repeat(10, axis=0)
    axes[1, 2].set_title("Cross-section with threshold -0.8")
    axes[1, 2].contourf(line, np.linspace(-1.5, 1.5, 10), contour2, alpha=0.4, cmap=cm)
    discrete_scatter(X_train[mask, 0], np.zeros(bla), y_train[mask], alpha=.1, ax=axes[1, 2])
    axes[1, 2].plot(line, svc.decision_function(np.c_[line, 10 * np.ones(100)]), c='k')
    axes[1, 2].set_xlim(X_train.min(), X_train.max())
    axes[1, 2].set_ylim(-1.5, 1.5)
    axes[1, 2].set_xticks(())
    axes[1, 2].set_ylabel("Decision value")
    axes[1, 0].legend(['negative class', 'positive class'])
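

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module; the package name below is an
# assumption based on the relative imports above): these helpers draw onto the
# current matplotlib figure, so a minimal call might look like
#
#     import matplotlib.pyplot as plt
#     from mglearn.plot_metrics import plot_decision_threshold
#
#     plot_decision_threshold()  # 2x3 figure comparing thresholds 0 and -0.8
#     plt.show()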