Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download

📚 The CoCalc Library - books, templates and other resources

132928 views
License: OTHER
1
import matplotlib.pyplot as plt
2
import numpy as np
3
from sklearn.svm import LinearSVC
4
from sklearn.datasets import make_blobs
5
6
from .plot_helpers import discrete_scatter
7
8
9
def plot_linear_svc_regularization():
    """Plot LinearSVC decision boundaries for three regularization strengths.

    Builds a 30-point, two-center blob dataset (``random_state=4``),
    relabels the points at indices 7 and 27 as class 0, and draws one panel
    per value of ``C`` in ``[1e-2, 10, 1e3]``. Each panel shows the labelled
    points and the straight line on which the fitted ``LinearSVC`` decision
    function is zero.

    The figure is created with ``plt.subplots``. It is neither shown nor
    returned; call ``plt.show()`` (or use an inline backend) to view it.
    """
    X, y = make_blobs(centers=2, random_state=4, n_samples=30)
    _, axes = plt.subplots(1, 3, figsize=(12, 4))

    # Hand-picked relabeling of two points. Presumably chosen so that small
    # and large C yield visibly different boundaries -- TODO confirm.
    y[7] = 0
    y[27] = 0
    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5

    # Sample the boundary over the visible x-range instead of a hard-coded
    # interval, so the line always spans the whole panel. This is loop-invariant.
    xx = np.linspace(x_min, x_max)

    for ax, C in zip(axes, [1e-2, 10, 1e3]):
        discrete_scatter(X[:, 0], X[:, 1], y, ax=ax)

        svm = LinearSVC(C=C, tol=0.00001, dual=False).fit(X, y)
        w = svm.coef_[0]
        b = svm.intercept_[0]
        # Decision boundary: w[0] * x + w[1] * y + b == 0.
        if w[1] != 0:
            ax.plot(xx, -(w[0] * xx + b) / w[1], c='k')
        elif w[0] != 0:
            # Vertical boundary: solving for y would divide by zero.
            ax.axvline(-b / w[0], color='k')
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_xticks(())
        ax.set_yticks(())
        ax.set_title("C = %f" % C)
    axes[0].legend(loc="best")
34
35
def _main():
    """Draw the LinearSVC regularization comparison and open a plot window."""
    plot_linear_svc_regularization()
    plt.show()


if __name__ == "__main__":
    _main()
38
39