Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/vision/mnist_convnet.py
3507 views
"""
Title: Simple MNIST convnet
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2015/06/19
Last modified: 2020/04/21
Description: A simple convnet that achieves ~99% test accuracy on MNIST.
Accelerator: GPU
"""

"""
## Setup
"""

14
import numpy as np
15
import keras
16
from keras import layers
17
"""
## Prepare the data
"""

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range, then append a trailing channel axis so
# each image matches `input_shape`, i.e. (28, 28, 1).
x_train = np.expand_dims(x_train.astype("float32") / 255, axis=-1)
x_test = np.expand_dims(x_test.astype("float32") / 255, axis=-1)

# Report the resulting array shape and the number of samples in each split.
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# Convert class vectors to binary class matrices with `num_classes` columns.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

"""
## Build the model
"""

# Small convnet: two Conv2D + MaxPooling2D stages followed by a flattened,
# dropout-regularized softmax layer with `num_classes` outputs.
model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        # First convolutional stage: 32 filters of size 3x3, then 2x2 pooling.
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        # Second convolutional stage: 64 filters of size 3x3, then 2x2 pooling.
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        # Classifier head: flatten, apply dropout at rate 0.5, then softmax.
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

# Print the layer-by-layer summary of the assembled model.
model.summary()

"""
## Train the model
"""

# Training hyperparameters.
batch_size = 128
epochs = 15

# Configure the loss, optimizer, and the metric reported during training.
model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

# Fit on the training arrays, passing validation_split=0.1 so a 0.1 fraction
# of the training data is used for validation.
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.1,
)

"""
## Evaluate the trained model
"""

# Evaluate on the test arrays without progress output (verbose=0).
score = model.evaluate(x_test, y_test, verbose=0)

# The first entry is reported as the loss and the second as the accuracy.
test_loss, test_accuracy = score[0], score[1]
print("Test loss:", test_loss)
print("Test accuracy:", test_accuracy)