📚 The CoCalc Library - books, templates and other resources
License: OTHER
"""1Small helpers for code that is not shown in the notebooks2"""34from sklearn import neighbors, datasets, linear_model5import pylab as pl6import numpy as np7from matplotlib.colors import ListedColormap89# Create color maps for 3-class classification problem, as with iris10cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])11cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])1213def plot_iris_knn():14iris = datasets.load_iris()15X = iris.data[:, :2] # we only take the first two features. We could16# avoid this ugly slicing by using a two-dim dataset17y = iris.target1819knn = neighbors.KNeighborsClassifier(n_neighbors=5)20knn.fit(X, y)2122x_min, x_max = X[:, 0].min() - .1, X[:, 0].max() + .123y_min, y_max = X[:, 1].min() - .1, X[:, 1].max() + .124xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),25np.linspace(y_min, y_max, 100))26Z = knn.predict(np.c_[xx.ravel(), yy.ravel()])2728# Put the result into a color plot29Z = Z.reshape(xx.shape)30pl.figure()31pl.pcolormesh(xx, yy, Z, cmap=cmap_light)3233# Plot also the training points34pl.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold)35pl.xlabel('sepal length (cm)')36pl.ylabel('sepal width (cm)')37pl.axis('tight')383940def plot_polynomial_regression():41rng = np.random.RandomState(0)42x = 2*rng.rand(100) - 14344f = lambda t: 1.2 * t**2 + .1 * t**3 - .4 * t **5 - .5 * t ** 945y = f(x) + .4 * rng.normal(size=100)4647x_test = np.linspace(-1, 1, 100)4849pl.figure()50pl.scatter(x, y, s=4)5152X = np.array([x**i for i in range(5)]).T53X_test = np.array([x_test**i for i in range(5)]).T54regr = linear_model.LinearRegression()55regr.fit(X, y)56pl.plot(x_test, regr.predict(X_test), label='4th order')5758X = np.array([x**i for i in range(10)]).T59X_test = np.array([x_test**i for i in range(10)]).T60regr = linear_model.LinearRegression()61regr.fit(X, y)62pl.plot(x_test, regr.predict(X_test), label='9th order')6364pl.legend(loc='best')65pl.axis('tight')66pl.title('Fitting a 4th and a 9th order polynomial')6768pl.figure()69pl.scatter(x, y, s=4)70pl.plot(x_test, f(x_test), label="truth")71pl.axis('tight')72pl.title('Ground truth (9th order polynomial)')7374757677