Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/keras_recipes/tfrecord.py
3507 views
1
"""
2
Title: How to train a Keras model on TFRecord files
3
Author: Amy MiHyun Jang
4
Date created: 2020/07/29
5
Last modified: 2020/08/07
6
Description: Loading TFRecords for computer vision models.
7
Accelerator: TPU
8
"""
9
10
"""
11
## Introduction + Set Up
12
13
TFRecords store a sequence of binary records, read linearly. They are a useful format for
14
storing data because they can be read efficiently. Learn more about TFRecords
15
[here](https://www.tensorflow.org/tutorials/load_data/tfrecord).
16
17
We'll explore how we can easily load in TFRecords for our melanoma classifier.
18
"""
19
20
import tensorflow as tf
21
from functools import partial
22
import matplotlib.pyplot as plt
23
24
# Use a TPU when one is reachable; otherwise fall back to the default
# distribution strategy (CPU / single GPU).
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except Exception as err:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit are not silently swallowed. Report why
    # the TPU was not used instead of hiding the error completely.
    print("TPU not available, using default strategy:", repr(err))
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)
31
32
"""
33
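We want a fairly large batch size because our data is imbalanced: larger batches are more likely to contain at least a few examples of the rare (malignant) class.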
34
35
"""
36
37
# Let tf.data choose parallelism and buffer sizes dynamically.
AUTOTUNE = tf.data.AUTOTUNE

# GCS bucket holding the melanoma TFRecords (under its tfrecords/ folder).
GCS_PATH = "gs://kds-b38ce1b823c3ae623f5691483dbaa0f0363f04b0d6a90b63cf69946e"

# Expected [height, width] of every image. decode_image reshapes to
# [*IMAGE_SIZE, 3]; reshaping does not resize, so the stored JPEGs must
# already have this many pixels.
IMAGE_SIZE = [1024, 1024]
BATCH_SIZE = 64
41
42
"""
43
## Load the data
44
"""
45
46
# tf.io.gfile.glob does not document any ordering for its results, so sort
# them: otherwise the train/validation split below could change between runs.
FILENAMES = sorted(tf.io.gfile.glob(GCS_PATH + "/tfrecords/train*.tfrec"))
if not FILENAMES:
    # Fail early with a clear message instead of silently training on an
    # empty file list.
    raise FileNotFoundError(
        f"No training TFRecords matched {GCS_PATH}/tfrecords/train*.tfrec"
    )
# Hold out the last 10% of the shard *files* (not records) for validation.
split_ind = int(0.9 * len(FILENAMES))
TRAINING_FILENAMES, VALID_FILENAMES = FILENAMES[:split_ind], FILENAMES[split_ind:]

TEST_FILENAMES = sorted(tf.io.gfile.glob(GCS_PATH + "/tfrecords/test*.tfrec"))
print("Train TFRecord Files:", len(TRAINING_FILENAMES))
print("Validation TFRecord Files:", len(VALID_FILENAMES))
print("Test TFRecord Files:", len(TEST_FILENAMES))
54
55
"""
56
### Decoding the data
57
58
The images have to be decoded into tensors so that they are valid inputs to our model.
59
As the images are in RGB color, we specify 3 channels.
60
61
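We also reshape each decoded image to a fixed `[1024, 1024, 3]` shape so that every image tensor has the same, statically known shape (note that `tf.reshape` does not resize, so the source images must already be 1024x1024).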
62
"""
63
64
65
def decode_image(image):
    """Turn JPEG-encoded bytes into a float32 RGB image tensor.

    Args:
        image: String tensor holding the encoded JPEG bytes.

    Returns:
        A float32 tensor of shape ``[*IMAGE_SIZE, 3]``. Pixel values keep their
        original 0-255 range; the cast changes the dtype but does not rescale.
    """
    rgb_pixels = tf.image.decode_jpeg(image, channels=3)
    as_float = tf.cast(rgb_pixels, tf.float32)
    # tf.reshape gives every image the same static shape. It does not resize,
    # so the JPEG must already contain IMAGE_SIZE pixels.
    return tf.reshape(as_float, [*IMAGE_SIZE, 3])
70
71
72
"""
73
As we load in our data, we need both our `X` and our `Y`. The X is our image; the model
74
will find features and patterns in our image dataset. We want to predict Y, the
75
probability that the lesion in the image is malignant. We will go through our TFRecords
76
and parse out the image and the target values.
77
"""
78
79
80
def read_tfrecord(example, labeled):
    """Parse one serialized TFRecord example into model inputs.

    Args:
        example: A single serialized example (scalar string tensor).
        labeled: If True, the record is expected to carry an int64 ``target``
            feature in addition to the encoded ``image``.

    Returns:
        ``(image, label)`` with ``label`` cast to int32 when ``labeled`` is
        True; otherwise just the decoded image tensor.
    """
    # Every record has the encoded image; only labeled records have a target.
    feature_spec = {"image": tf.io.FixedLenFeature([], tf.string)}
    if labeled:
        feature_spec["target"] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed["image"])
    if not labeled:
        return image
    return image, tf.cast(parsed["target"], tf.int32)
97
98
99
"""
100
### Define loading methods
101
102
Our dataset is not ordered in any meaningful way, so the order can be ignored when
103
loading our dataset. By ignoring the order and reading files as soon as they come in, it
104
will take a shorter time to load the data.
105
"""
106
107
108
def load_dataset(filenames, labeled=True):
    """Read TFRecord files into an unbatched dataset of decoded examples.

    Args:
        filenames: TFRecord file paths to read.
        labeled: Forwarded to ``read_tfrecord``. If True, elements are
            ``(image, label)`` pairs; otherwise elements are images only.

    Returns:
        A ``tf.data.Dataset`` of decoded examples. Element order is not
        deterministic, because ordering is disabled for throughput.
    """
    ignore_order = tf.data.Options()
    # Let elements be produced as soon as they are ready rather than in their
    # original order, trading reproducible ordering for speed.
    ignore_order.experimental_deterministic = False
    # Without num_parallel_reads, TFRecordDataset reads the files one after
    # another. Setting it reads several files concurrently and interleaves
    # their records, which the original comment wrongly said happened
    # automatically.
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    dataset = dataset.with_options(ignore_order)
    dataset = dataset.map(
        partial(read_tfrecord, labeled=labeled), num_parallel_calls=AUTOTUNE
    )
    return dataset
122
123
124
"""
125
We define the following function to get our different datasets.
126
"""
127
128
129
def get_dataset(filenames, labeled=True):
    """Build a shuffled, batched and prefetched input pipeline.

    Args:
        filenames: TFRecord file paths, passed to ``load_dataset``.
        labeled: If True, batches are ``(images, labels)``; otherwise images only.

    Returns:
        A ``tf.data.Dataset`` yielding batches of up to ``BATCH_SIZE`` elements.
        ``batch`` keeps the final partial batch by default.
    """
    dataset = load_dataset(filenames, labeled=labeled)
    # Shuffle individual examples before batching so each batch mixes records.
    # NOTE(review): the buffer holds decoded float32 images. With IMAGE_SIZE
    # 1024x1024x3 that is ~12 MiB each, so 2048 entries can need ~24 GiB of
    # host RAM. Consider a smaller buffer if memory is tight.
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    # Prefetch last so whole batches are prepared while the model consumes the
    # current one. Prefetching before batch() only buffered single examples.
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    return dataset
135
136
137
"""
138
### Visualize input images
139
"""
140
141
# Build the three input pipelines. The test split is loaded unlabeled
# (images only).
train_dataset = get_dataset(TRAINING_FILENAMES, labeled=True)
valid_dataset = get_dataset(VALID_FILENAMES, labeled=True)
test_dataset = get_dataset(TEST_FILENAMES, labeled=False)

# Pull one (images, labels) batch from the training pipeline to visualize.
image_batch, label_batch = next(iter(train_dataset))
146
147
148
def show_batch(image_batch, label_batch):
    """Plot up to 25 images from a batch in a 5x5 grid, titled by their label.

    Args:
        image_batch: Images with raw 0-255 pixel values, as produced by
            ``decode_image``. They are divided by 255 for display.
        label_batch: Labels aligned with ``image_batch``. A truthy label is
            titled "MALIGNANT"; otherwise the image is titled "BENIGN".
    """
    plt.figure(figsize=(10, 10))
    # Don't index past the end of a batch smaller than the 5x5 grid
    # (e.g. a final partial batch).
    for n in range(min(25, len(image_batch))):
        plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n] / 255.0)
        plt.title("MALIGNANT" if label_batch[n] else "BENIGN")
        plt.axis("off")


show_batch(image_batch.numpy(), label_batch.numpy())
161
162
"""
163
## Building our model
164
"""
165
166
"""
167
### Define callbacks
168
169
The following schedule lowers the learning rate as training progresses: it starts at 0.01 and is multiplied by 0.96 every 20 training steps.
171
172
We can use callbacks to stop training when there are no improvements in the model. At the
173
end of the training process, the model will restore the weights of its best iteration.
174
"""
175
176
# Learning-rate schedule for Adam: start at 0.01 and multiply by 0.96 after
# every 20 optimizer steps. staircase=True applies the decay in discrete
# jumps rather than continuously.
initial_learning_rate = 0.01
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=20,
    decay_rate=0.96,
    staircase=True,
)

# Write the model to melanoma_model.h5, keeping only the best version seen
# so far.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="melanoma_model.h5",
    save_best_only=True,
)

# Stop after 10 epochs without improvement and roll back to the best weights.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=10,
    restore_best_weights=True,
)
188
189
"""
190
### Build our base model
191
192
Transfer learning is a great way to reap the benefits of a well-trained model without
193
having to train the model ourselves. For this notebook, we want to import the Xception
194
model. A more in-depth analysis of transfer learning can be found
195
[here](https://keras.io/examples/vision/image_classification_efficientnet_fine_tuning/).
196
197
We do not want our metric to be `accuracy` because our data is imbalanced. For our
198
example, we will be looking at the area under a ROC curve.
199
"""
200
201
202
def make_model():
    """Build and compile a binary classifier on top of a frozen Xception base.

    The model takes images of shape ``(*IMAGE_SIZE, 3)`` with raw pixel values,
    applies Xception's own input preprocessing, and outputs a single sigmoid
    probability. The positive class is "malignant", matching the labels used
    elsewhere in this script. It is compiled with Adam driven by
    ``lr_schedule``, binary cross-entropy loss, and ROC AUC as the metric.

    Call it inside ``strategy.scope()`` so the model's variables are created
    by the active distribution strategy.

    Returns:
        A compiled ``tf.keras.Model``.
    """
    base_model = tf.keras.applications.Xception(
        input_shape=(*IMAGE_SIZE, 3), include_top=False, weights="imagenet"
    )
    # Freeze the pretrained backbone; only the small head below is trained.
    base_model.trainable = False

    inputs = tf.keras.layers.Input([*IMAGE_SIZE, 3])
    x = tf.keras.applications.xception.preprocess_input(inputs)
    # training=False keeps the backbone (including its BatchNormalization
    # layers) in inference mode, as the Keras transfer-learning guide
    # recommends. It stays correct if the base is later unfrozen for
    # fine-tuning.
    x = base_model(x, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(8, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.7)(x)
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
        loss="binary_crossentropy",
        # `metrics` is documented as a list. Keras 3 rejects a bare metric
        # object here.
        metrics=[tf.keras.metrics.AUC(name="auc")],
    )

    return model
226
227
228
"""
229
## Train the model
230
"""
231
232
# The model's variables must be created inside the distribution strategy's
# scope. Training itself can be launched outside it.
with strategy.scope():
    model = make_model()

training_callbacks = [checkpoint_cb, early_stopping_cb]
history = model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=2,
    callbacks=training_callbacks,
)
241
242
"""
243
## Predict results
244
245
We'll use our model to predict results for our test dataset images. Values closer to `0`
246
are more likely to be benign and values closer to `1` are more likely to be malignant.
247
"""
248
249
250
def show_batch_predictions(image_batch):
    """Plot up to 25 test images in a 5x5 grid, titled with model predictions.

    Uses the module-level ``model``. Each title is the model's sigmoid output
    for that image: values near 0 lean benign, values near 1 lean malignant.

    Args:
        image_batch: Images with raw 0-255 pixel values, as produced by
            ``decode_image``. They go to the model unscaled, because the model
            applies Xception preprocessing itself. They are divided by 255
            only for display.
    """
    n_shown = min(25, len(image_batch))
    # One batched predict() call instead of one call per image. predict() is
    # designed for batch processing, and 25 single-image calls each pay its
    # per-call overhead. predictions[n] has the same shape as the old
    # model.predict(single_image)[0], so the titles are unchanged.
    predictions = model.predict(image_batch[:n_shown])
    plt.figure(figsize=(10, 10))
    for n in range(n_shown):
        plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n] / 255.0)
        plt.title(predictions[n])
        plt.axis("off")


image_batch = next(iter(test_dataset))

show_batch_predictions(image_batch)
263
264