"""1Title: Creating TFRecords2Author: [Dimitre Oliveira](https://www.linkedin.com/in/dimitre-oliveira-7a1a0113a/)3Date created: 2021/02/274Last modified: 2023/12/205Description: Converting data to the TFRecord format.6Accelerator: GPU7"""89"""10## Introduction1112The TFRecord format is a simple format for storing a sequence of binary records.13Converting your data into TFRecord has many advantages, such as:1415- **More efficient storage**: the TFRecord data can take up less space than the original16data; it can also be partitioned into multiple files.17- **Fast I/O**: the TFRecord format can be read with parallel I/O operations, which is18useful for [TPUs](https://www.tensorflow.org/guide/tpu) or multiple hosts.19- **Self-contained files**: the TFRecord data can be read from a single source—for20example, the [COCO2017](https://cocodataset.org/) dataset originally stores data in21two folders ("images" and "annotations").2223An important use case of the TFRecord data format is training on TPUs. First, TPUs are24fast enough to benefit from optimized I/O operations. In addition, TPUs require25data to be stored remotely (e.g. 
on Google Cloud Storage) and using the TFRecord format26makes it easier to load the data without batch-downloading.2728Performance using the TFRecord format can be further improved if you also use29it with the [tf.data](https://www.tensorflow.org/guide/data) API.3031In this example you will learn how to convert data of different types (image, text, and32numeric) into TFRecord.3334**Reference**3536- [TFRecord and tf.train.Example](https://www.tensorflow.org/tutorials/load_data/tfrecord)373839## Dependencies40"""4142import os4344os.environ["KERAS_BACKEND"] = "tensorflow"45import keras46import json47import pprint48import tensorflow as tf49import matplotlib.pyplot as plt5051"""52## Download the COCO2017 dataset5354We will be using the [COCO2017](https://cocodataset.org/) dataset, because it has many55different types of features, including images, floating point data, and lists.56It will serve as a good example of how to encode different features into the TFRecord57format.5859This dataset has two sets of fields: images and annotation meta-data.6061The images are a collection of JPG files and the meta-data are stored in a JSON file62which, according to the [official site](https://cocodataset.org/#format-data),63contains the following properties:6465```66id: int,67image_id: int,68category_id: int,69segmentation: RLE or [polygon], object segmentation mask70bbox: [x,y,width,height], object bounding box coordinates71area: float, area of the bounding box72iscrowd: 0 or 1, is single object or a collection73```74"""7576root_dir = "datasets"77tfrecords_dir = "tfrecords"78images_dir = os.path.join(root_dir, "val2017")79annotations_dir = os.path.join(root_dir, "annotations")80annotation_file = os.path.join(annotations_dir, "instances_val2017.json")81images_url = "http://images.cocodataset.org/zips/val2017.zip"82annotations_url = (83"http://images.cocodataset.org/annotations/annotations_trainval2017.zip"84)8586# Download image files87if not os.path.exists(images_dir):88image_zip = 
keras.utils.get_file(89"images.zip",90cache_dir=os.path.abspath("."),91origin=images_url,92extract=True,93)94os.remove(image_zip)9596# Download caption annotation files97if not os.path.exists(annotations_dir):98annotation_zip = keras.utils.get_file(99"captions.zip",100cache_dir=os.path.abspath("."),101origin=annotations_url,102extract=True,103)104os.remove(annotation_zip)105106print("The COCO dataset has been downloaded and extracted successfully.")107108with open(annotation_file, "r") as f:109annotations = json.load(f)["annotations"]110111print(f"Number of images: {len(annotations)}")112113"""114### Contents of the COCO2017 dataset115"""116117pprint.pprint(annotations[60])118119"""120## Parameters121122`num_samples` is the number of data samples on each TFRecord file.123124`num_tfrecords` is total number of TFRecords that we will create.125"""126127num_samples = 4096128num_tfrecords = len(annotations) // num_samples129if len(annotations) % num_samples:130num_tfrecords += 1 # add one record if there are any remaining samples131132if not os.path.exists(tfrecords_dir):133os.makedirs(tfrecords_dir) # creating TFRecords output folder134135"""136## Define TFRecords helper functions137"""138139140def image_feature(value):141"""Returns a bytes_list from a string / byte."""142return tf.train.Feature(143bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()])144)145146147def bytes_feature(value):148"""Returns a bytes_list from a string / byte."""149return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode()]))150151152def float_feature(value):153"""Returns a float_list from a float / double."""154return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))155156157def int64_feature(value):158"""Returns an int64_list from a bool / enum / int / uint."""159return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))160161162def float_feature_list(value):163"""Returns a list of float_list from a float / double."""164return 
tf.train.Feature(float_list=tf.train.FloatList(value=value))165166167def create_example(image, path, example):168feature = {169"image": image_feature(image),170"path": bytes_feature(path),171"area": float_feature(example["area"]),172"bbox": float_feature_list(example["bbox"]),173"category_id": int64_feature(example["category_id"]),174"id": int64_feature(example["id"]),175"image_id": int64_feature(example["image_id"]),176}177return tf.train.Example(features=tf.train.Features(feature=feature))178179180def parse_tfrecord_fn(example):181feature_description = {182"image": tf.io.FixedLenFeature([], tf.string),183"path": tf.io.FixedLenFeature([], tf.string),184"area": tf.io.FixedLenFeature([], tf.float32),185"bbox": tf.io.VarLenFeature(tf.float32),186"category_id": tf.io.FixedLenFeature([], tf.int64),187"id": tf.io.FixedLenFeature([], tf.int64),188"image_id": tf.io.FixedLenFeature([], tf.int64),189}190example = tf.io.parse_single_example(example, feature_description)191example["image"] = tf.io.decode_jpeg(example["image"], channels=3)192example["bbox"] = tf.sparse.to_dense(example["bbox"])193return example194195196"""197## Generate data in the TFRecord format198199Let's generate the COCO2017 data in the TFRecord format. 
The format will be200`file_{number}.tfrec` (this is optional, but including the number sequences in the file201names can make counting easier).202"""203204for tfrec_num in range(num_tfrecords):205samples = annotations[(tfrec_num * num_samples) : ((tfrec_num + 1) * num_samples)]206207with tf.io.TFRecordWriter(208tfrecords_dir + "/file_%.2i-%i.tfrec" % (tfrec_num, len(samples))209) as writer:210for sample in samples:211image_path = f"{images_dir}/{sample['image_id']:012d}.jpg"212image = tf.io.decode_jpeg(tf.io.read_file(image_path))213example = create_example(image, image_path, sample)214writer.write(example.SerializeToString())215216"""217## Explore one sample from the generated TFRecord218"""219220raw_dataset = tf.data.TFRecordDataset(f"{tfrecords_dir}/file_00-{num_samples}.tfrec")221parsed_dataset = raw_dataset.map(parse_tfrecord_fn)222223for features in parsed_dataset.take(1):224for key in features.keys():225if key != "image":226print(f"{key}: {features[key]}")227228print(f"Image shape: {features['image'].shape}")229plt.figure(figsize=(7, 7))230plt.imshow(features["image"].numpy())231plt.show()232233"""234## Train a simple model using the generated TFRecords235236Another advantage of TFRecord is that you are able to add many features to it and later237use only a few of them, in this case, we are going to use only `image` and `category_id`.238239"""240241"""242243## Define dataset helper functions244"""245246247def prepare_sample(features):248image = keras.ops.image.resize(features["image"], size=(224, 224))249return image, features["category_id"]250251252def get_dataset(filenames, batch_size):253dataset = (254tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)255.map(parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)256.map(prepare_sample, num_parallel_calls=AUTOTUNE)257.shuffle(batch_size * 10)258.batch(batch_size)259.prefetch(AUTOTUNE)260)261return dataset262263264train_filenames = tf.io.gfile.glob(f"{tfrecords_dir}/*.tfrec")265batch_size = 32266epochs 
= 1267steps_per_epoch = 50268AUTOTUNE = tf.data.AUTOTUNE269270input_tensor = keras.layers.Input(shape=(224, 224, 3), name="image")271model = keras.applications.EfficientNetB0(272input_tensor=input_tensor, weights=None, classes=91273)274275276model.compile(277optimizer=keras.optimizers.Adam(),278loss=keras.losses.SparseCategoricalCrossentropy(),279metrics=[keras.metrics.SparseCategoricalAccuracy()],280)281282283model.fit(284x=get_dataset(train_filenames, batch_size),285epochs=epochs,286steps_per_epoch=steps_per_epoch,287verbose=1,288)289290"""291## Conclusion292293This example demonstrates that instead of reading images and annotations from different294sources you can have your data coming from a single source thanks to TFRecord.295This process can make storing and reading data simpler and more efficient.296For more information, you can go to the [TFRecord and297tf.train.Example](https://www.tensorflow.org/tutorials/load_data/tfrecord) tutorial.298"""299300301
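
"""
The shard arithmetic from the Parameters section and the file naming used when
generating the TFRecords can be checked without downloading any data. The block below is
a minimal sketch under stated assumptions: `shard_plan` is a hypothetical helper that is
not part of this example, and the sample counts are made-up numbers chosen only to show
that the last file holds the remainder.

```python
import math


def shard_plan(total_samples, samples_per_shard):
    # Hypothetical helper: lists (file_name, start, stop) for every shard,
    # using the same slicing and "file_%.2i-%i.tfrec" naming as the loop above.
    num_shards = math.ceil(total_samples / samples_per_shard)
    plan = []
    for shard in range(num_shards):
        start = shard * samples_per_shard
        # The last shard may be shorter than samples_per_shard.
        stop = min(start + samples_per_shard, total_samples)
        plan.append((f"file_{shard:02d}-{stop - start}.tfrec", start, stop))
    return plan


# Illustrative numbers only: 10000 samples split into shards of 4096.
for name, start, stop in shard_plan(10000, 4096):
    print(name, start, stop)
```

With these numbers the plan contains three files, and only the first two hold exactly
4096 samples. This is why a file name built from `num_samples` is only guaranteed to match
a full shard, such as the first one when the dataset has at least `num_samples` entries.
"""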