"""
Title: How to train a Keras model on TFRecord files
Author: Amy MiHyun Jang
Date created: 2020/07/29
Last modified: 2020/08/07
Description: Loading TFRecords for computer vision models.
Accelerator: TPU
"""

"""
## Introduction + Set Up

TFRecords store a sequence of binary records that are read linearly. They are a useful
format for storing data because they can be read efficiently. Learn more about TFRecords
[here](https://www.tensorflow.org/tutorials/load_data/tfrecord).

We'll explore how we can easily load TFRecords for our melanoma classifier.
"""

import tensorflow as tf
from functools import partial
import matplotlib.pyplot as plt

try:
    # Connect to a TPU if one is available; otherwise fall back to the default strategy.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except Exception:
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

"""
We want a bigger batch size because our data is not balanced: larger batches make it more
likely that each batch contains at least a few malignant examples.
"""

AUTOTUNE = tf.data.AUTOTUNE
GCS_PATH = "gs://kds-b38ce1b823c3ae623f5691483dbaa0f0363f04b0d6a90b63cf69946e"
BATCH_SIZE = 64
IMAGE_SIZE = [1024, 1024]

"""
## Load the data
"""

FILENAMES = tf.io.gfile.glob(GCS_PATH + "/tfrecords/train*.tfrec")
split_ind = int(0.9 * len(FILENAMES))
TRAINING_FILENAMES, VALID_FILENAMES = FILENAMES[:split_ind], FILENAMES[split_ind:]

TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + "/tfrecords/test*.tfrec")
print("Train TFRecord Files:", len(TRAINING_FILENAMES))
print("Validation TFRecord Files:", len(VALID_FILENAMES))
print("Test TFRecord Files:", len(TEST_FILENAMES))

"""
### Decoding the data

The images are stored as JPEG-encoded bytes, so they have to be decoded into tensors before
they are a valid input to our model. As the images are RGB, we specify 3 channels.

We also reshape each image to `[*IMAGE_SIZE, 3]` so that all of the images have the same,
statically known shape.
"""


def decode_image(image):
    image = tf.image.decode_jpeg(image, channels=3)  # JPEG bytes -> uint8 tensor
    image = tf.cast(image, tf.float32)  # pixel values stay in [0, 255]
    image = tf.reshape(image, [*IMAGE_SIZE, 3])  # every image must be 1024x1024 RGB
    return image


"""
As we load in our data, we need both our `X` and our `Y`. The X is our image; the model
will find features and patterns in our image dataset. Y is the target: `1` if the lesion in
the image is malignant and `0` if it is benign. The model will predict the probability that
the lesion is malignant. We will go through our TFRecords and parse out the image and the
target values.
"""


def read_tfrecord(example, labeled):
    # Labeled (training) records carry a "target"; unlabeled (test) records only an "image".
    tfrecord_format = (
        {
            "image": tf.io.FixedLenFeature([], tf.string),
            "target": tf.io.FixedLenFeature([], tf.int64),
        }
        if labeled
        else {
            "image": tf.io.FixedLenFeature([], tf.string),
        }
    )
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example["image"])
    if labeled:
        label = tf.cast(example["target"], tf.int32)
        return image, label
    return image


"""
### Define loading methods

Our dataset is not ordered in any meaningful way, so the order can be ignored when
loading it. By ignoring the order and using records as soon as they are read, rather than
in their original order, the data loads faster.
"""


def load_dataset(filenames, labeled=True):
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False  # disable order, increase speed
    dataset = tf.data.TFRecordDataset(
        filenames, num_parallel_reads=AUTOTUNE
    )  # reads several files in parallel and interleaves their records
    dataset = dataset.with_options(
        ignore_order
    )  # uses data as soon as it streams in, rather than in its original order
    dataset = dataset.map(
        partial(read_tfrecord, labeled=labeled), num_parallel_calls=AUTOTUNE
    )
    # returns a dataset of (image, label) pairs if labeled=True or just images if labeled=False
    return dataset


"""
We define the following function to get our different datasets. It shuffles the records,
groups them into batches, and prefetches batches so that the next one is prepared while the
model trains on the current one.
"""


def get_dataset(filenames, labeled=True):
    dataset = load_dataset(filenames, labeled=labeled)
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)  # prefetch whole batches, last step
    return dataset


"""
### Visualize input images
"""

train_dataset = get_dataset(TRAINING_FILENAMES)
valid_dataset = get_dataset(VALID_FILENAMES)
test_dataset = get_dataset(TEST_FILENAMES, labeled=False)

image_batch, label_batch = next(iter(train_dataset))


def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10, 10))
    for n in range(25):
        plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n] / 255.0)  # scale [0, 255] floats to [0, 1] for display
        if label_batch[n]:
            plt.title("MALIGNANT")
        else:
            plt.title("BENIGN")
        plt.axis("off")


show_batch(image_batch.numpy(), label_batch.numpy())

"""
## Building our model
"""

"""
### Define callbacks

The following learning rate schedule decays the learning rate as training progresses. With
`staircase=True`, the learning rate is multiplied by `0.96` once every 20 optimizer steps
(batches), rather than once per epoch.

We can use callbacks to save the best model and to stop training when the model stops
improving. `ModelCheckpoint` with `save_best_only=True` keeps the model with the lowest
validation loss seen so far. `EarlyStopping` stops training once the validation loss has not
improved for 10 epochs, and with `restore_best_weights=True` the model then restores the
weights of its best epoch.
"""

initial_learning_rate = 0.01
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=20, decay_rate=0.96, staircase=True
)

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "melanoma_model.h5", save_best_only=True
)

early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=10, restore_best_weights=True
)

"""
### Build our base model

Transfer learning is a great way to reap the benefits of a well-trained model without
having to train the model ourselves. For this notebook, we want to import the Xception
model. A more in-depth analysis of transfer learning can be found
[here](https://keras.io/examples/vision/image_classification_efficientnet_fine_tuning/).

We do not want our metric to be `accuracy` because our data is imbalanced: a model that
predicts "benign" for every image would already score a high accuracy. For our example, we
will instead look at the area under the ROC curve (AUC).
"""


def make_model():
    base_model = tf.keras.applications.Xception(
        input_shape=(*IMAGE_SIZE, 3), include_top=False, weights="imagenet"
    )

    base_model.trainable = False  # freeze the pretrained weights

    inputs = tf.keras.layers.Input([*IMAGE_SIZE, 3])
    x = tf.keras.applications.xception.preprocess_input(inputs)  # scale pixels to [-1, 1]
    x = base_model(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(8, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.7)(x)
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
        loss="binary_crossentropy",
        metrics=[tf.keras.metrics.AUC(name="auc")],
    )

    return model


"""
## Train the model
"""

with strategy.scope():
    model = make_model()

history = model.fit(
    train_dataset,
    epochs=2,
    validation_data=valid_dataset,
    callbacks=[checkpoint_cb, early_stopping_cb],
)

"""
## Predict results

We'll use our model to predict results for our test dataset images. Values closer to `0`
are more likely to be benign and values closer to `1` are more likely to be malignant.
"""


def show_batch_predictions(image_batch):
    plt.figure(figsize=(10, 10))
    for n in range(25):
        plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n] / 255.0)
        img_array = tf.expand_dims(image_batch[n], axis=0)  # add a batch dimension
        prediction = model.predict(img_array, verbose=0)[0][0]  # predicted probability
        plt.title(f"{prediction:.2f}")
        plt.axis("off")


image_batch = next(iter(test_dataset))

show_batch_predictions(image_batch)
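
"""
To make the TFRecord format from the introduction more concrete, here is a minimal sketch.
It writes three toy records and reads them back with the same
`tf.io.parse_single_example` feature spec that `read_tfrecord` uses. The tiny 8x8 images,
the temporary file path, and the `make_example` helper are hypothetical stand-ins. The only
assumption is that, like the melanoma records, each example stores JPEG bytes under "image"
and an integer label under "target".

```python
import os
import tempfile

import tensorflow as tf


def make_example(image_bytes, target):
    # Hypothetical helper: packs JPEG bytes and an integer label into a
    # tf.train.Example with the feature names that read_tfrecord expects.
    feature = {
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
        "target": tf.train.Feature(int64_list=tf.train.Int64List(value=[target])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


# Write three toy records; solid-color 8x8 images stand in for the real JPEGs.
path = os.path.join(tempfile.mkdtemp(), "toy.tfrec")
with tf.io.TFRecordWriter(path) as writer:
    for target in [0, 1, 0]:
        pixels = tf.fill([8, 8, 3], tf.constant(target * 200, tf.uint8))
        jpeg_bytes = tf.io.encode_jpeg(pixels).numpy()
        writer.write(make_example(jpeg_bytes, target).SerializeToString())

# Read them back with the same labeled feature spec as read_tfrecord.
feature_spec = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "target": tf.io.FixedLenFeature([], tf.int64),
}
parsed = []
for record in tf.data.TFRecordDataset(path):
    example = tf.io.parse_single_example(record, feature_spec)
    image = tf.image.decode_jpeg(example["image"], channels=3)
    parsed.append((image.shape.as_list(), int(example["target"])))

print(parsed)  # [([8, 8, 3], 0), ([8, 8, 3], 1), ([8, 8, 3], 0)]
```

The records come back in the order they were written because a single file is read
sequentially. Order-agnostic, parallel reading only matters once many files are read at once.
"""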
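
"""
The train/validation split in "Load the data" happens at the level of TFRecord files, not
individual records. The sketch below uses a hypothetical list of 16 shard names to show what
`int(0.9 * len(FILENAMES))` produces. Because `int()` truncates, 14 shards go to training and
2 to validation. The validation share is therefore 12.5% of the files rather than exactly 10%,
and its share of records depends on how many records each shard holds.

```python
def split_files(filenames, train_fraction=0.9):
    # Same rule as the tutorial: files before the split index are used for
    # training, the remaining files for validation. int() truncates toward 0.
    split_ind = int(train_fraction * len(filenames))
    return filenames[:split_ind], filenames[split_ind:]


# Hypothetical shard names standing in for the tf.io.gfile.glob(...) result.
filenames = [f"train{i:02d}.tfrec" for i in range(16)]
train_files, valid_files = split_files(filenames)
print(len(train_files), len(valid_files))  # 14 2
```

Splitting by file keeps every record of a shard on one side of the split, so no record can
appear in both the training and the validation set.
"""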
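
"""
Why does `decode_image` reshape at all? Inside `Dataset.map`, `decode_jpeg` can only promise
the channel count, so the decoded tensor's static shape is `(None, None, 3)`. The
`tf.reshape` pins it to a fully known shape that downstream code can rely on. Note that
`tf.reshape` does not resize: it fails at runtime if an image does not contain exactly that
many pixels. The sketch below checks the shapes on a toy JPEG. It assumes a hypothetical
`IMAGE_SIZE` of `[8, 8]` instead of `[1024, 1024]`.

```python
import tensorflow as tf

IMAGE_SIZE = [8, 8]  # toy size; the tutorial uses [1024, 1024]

# Encode an all-black 8x8 RGB image as JPEG bytes, like the "image" feature.
jpeg_bytes = tf.io.encode_jpeg(tf.zeros([8, 8, 3], tf.uint8))
dataset = tf.data.Dataset.from_tensors(jpeg_bytes)

# Decoding alone: height and width are unknown when the map function is traced.
decoded = dataset.map(lambda b: tf.image.decode_jpeg(b, channels=3))

# Casting and reshaping as in decode_image: the static shape becomes fully known.
reshaped = decoded.map(
    lambda img: tf.reshape(tf.cast(img, tf.float32), [*IMAGE_SIZE, 3])
)

print(decoded.element_spec.shape)   # (None, None, 3)
print(reshaped.element_spec.shape)  # (8, 8, 3)
```
"""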
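
"""
`load_dataset` turns off deterministic ordering through `tf.data.Options`. As far as we know,
recent TensorFlow 2.x releases also accept a `deterministic` argument on `Dataset.map` that
makes the same trade-off for a single transformation. In the sketch below, a hypothetical
`square` function stands in for `read_tfrecord`. With `deterministic=False`, results may
arrive in whatever order the parallel calls finish. Only the set of elements is guaranteed,
so the check sorts them first.

```python
import tensorflow as tf


def square(x):
    # Hypothetical stand-in for read_tfrecord: any per-record work done in map().
    return x * x


numbers = tf.data.Dataset.range(20)

# Parallel map that may emit results out of order in exchange for throughput.
squares = numbers.map(square, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)

values = [int(v) for v in squares]
print(sorted(values) == [i * i for i in range(20)])  # True: same elements, order not guaranteed
```
"""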
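
"""
The tf.data performance guide recommends ending an input pipeline with `prefetch`, so that
whole batches are prepared ahead of time while the model trains. The toy pipeline below
follows that shuffle, batch, then prefetch order on the numbers 0 to 9, mirroring the steps
of `get_dataset`. The batch size of 4 is a hypothetical value chosen for illustration.

```python
import tensorflow as tf


def make_pipeline(dataset, batch_size):
    # Shuffle records, group them into batches, then prefetch whole batches.
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)


pipeline = make_pipeline(tf.data.Dataset.range(10), batch_size=4)
batches = [batch.numpy().tolist() for batch in pipeline]
print([len(batch) for batch in batches])  # [4, 4, 2]
```

The final batch is smaller than the others. If every batch must have the same size,
`batch(..., drop_remainder=True)` drops that last partial batch instead.
"""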
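
"""
`ExponentialDecay` follows the documented formula
`initial_learning_rate * decay_rate ** (step / decay_steps)`, where the division becomes
integer division when `staircase=True`. The plain-Python sketch below re-implements that
formula with the tutorial's values to show how the learning rate steps down. The
`staircase_exponential_decay` helper is hypothetical and is not the Keras class itself.

```python
def staircase_exponential_decay(step, initial_learning_rate=0.01, decay_steps=20, decay_rate=0.96):
    # step // decay_steps is the integer division used when staircase=True,
    # so the rate stays constant for blocks of decay_steps optimizer steps.
    return initial_learning_rate * decay_rate ** (step // decay_steps)


for step in [0, 19, 20, 39, 40, 200]:
    print(step, round(staircase_exponential_decay(step), 6))
```

Because `step` counts optimizer updates (batches), the number of epochs between decays
depends on how many training batches each epoch contains.
"""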
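
"""
To see why accuracy is a poor metric for this imbalanced dataset, the sketch below compares
accuracy with ROC AUC on hypothetical labels with 2% malignant cases. AUC is computed in its
pairwise (Mann-Whitney) form: the probability that a random positive example scores higher
than a random negative one, with ties counting half. `tf.keras.metrics.AUC` approximates the
same area by bucketing predictions into a fixed number of thresholds (200 by default).

```python
def accuracy(labels, scores, threshold=0.5):
    # Fraction of examples whose thresholded score matches the 0/1 label.
    hits = sum((score >= threshold) == bool(label) for label, score in zip(labels, scores))
    return hits / len(labels)


def roc_auc(labels, scores):
    # Probability that a random positive scores above a random negative (ties count half).
    positives = [s for label, s in zip(labels, scores) if label == 1]
    negatives = [s for label, s in zip(labels, scores) if label == 0]
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# Hypothetical labels: 2 malignant (1) and 98 benign (0) lesions.
labels = [1] * 2 + [0] * 98
always_benign = [0.0] * 100      # predicts "benign" for everything
ranked = [0.3] * 2 + [0.1] * 98  # ranks every malignant case above every benign one

print(accuracy(labels, always_benign), roc_auc(labels, always_benign))  # 0.98 0.5
print(accuracy(labels, ranked), roc_auc(labels, ranked))                # 0.98 1.0
```

Both predictors reach 98% accuracy at a 0.5 threshold. Only AUC shows that the second one
separates malignant from benign lesions perfectly and the first one does not separate them at all.
"""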
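
"""
`decode_image` leaves pixels as floats in `[0, 255]`. The model rescales them itself because
`tf.keras.applications.xception.preprocess_input` uses the "tf" preprocessing mode, which maps
`[0, 255]` linearly onto `[-1, 1]`. This is also why `show_batch` divides by `255.0` only for
display. The hypothetical helper below writes out that mapping for single pixel values. It is
a sketch of the formula, not the Keras function itself.

```python
def xception_style_scale(pixel):
    # "tf" mode preprocessing: x / 127.5 - 1 maps 0 -> -1, 127.5 -> 0, 255 -> 1.
    return pixel / 127.5 - 1.0


print([xception_style_scale(p) for p in (0.0, 127.5, 255.0)])  # [-1.0, 0.0, 1.0]
```
"""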