Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
Download

📚 The CoCalc Library - books, templates and other resources

132926 views
License: OTHER
1
""" Examples to demonstrate how to write an image file to a TFRecord,
2
and how to read a TFRecord file using TFRecordReader.
3
Author: Chip Huyen
4
Prepared for the class CS 20SI: "TensorFlow for Deep Learning Research"
5
cs20si.stanford.edu
6
"""
7
import os
8
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
9
10
import sys
11
sys.path.append('..')
12
13
from PIL import Image
14
import numpy as np
15
import matplotlib.pyplot as plt
16
import tensorflow as tf
17
18
# Directory that holds the sample image and the TFRecord file written from it.
# The sample image is expected to be 480 x 640 x 3 (921600 uint8 values).
IMAGE_PATH = "data/"
20
21
def _int64_feature(value):
    """Wrap one integer in a ``tf.train.Feature`` backed by an Int64List.

    Args:
        value: the single integer to store.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds exactly ``[value]``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
23
24
def _bytes_feature(value):
    """Wrap one bytes value in a ``tf.train.Feature`` backed by a BytesList.

    Args:
        value: the single bytes object to store.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds exactly ``[value]``.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
26
27
def get_image_binary(filename):
    """Load an image file and return its shape and pixel data as raw bytes.

    The image could also be read with TensorFlow ops, but that requires
    building and running a graph; Pillow plus NumPy is simpler here.

    Args:
        filename: path to an image file that Pillow can open.

    Returns:
        A ``(shape_bytes, image_bytes)`` tuple. ``shape_bytes`` holds the
        array shape packed as little-endian int32 values. ``image_bytes``
        holds the pixel array as uint8 bytes.
    """
    # Using the image as a context manager closes the underlying file handle
    # once the pixels have been copied into the NumPy array.
    with Image.open(filename) as image:
        pixels = np.asarray(image, np.uint8)
    # Pin the byte order to little-endian int32. This matches the default of
    # tf.decode_raw(..., little_endian=True) even on big-endian hosts.
    shape = np.array(pixels.shape, dtype='<i4')
    return shape.tobytes(), pixels.tobytes()  # raw bytes of each array
35
36
def write_to_tfrecord(label, shape, binary_image, tfrecord_file):
    """Write one labelled image to a TFRecord file as a single ``tf.train.Example``.

    Each call opens a new writer on ``tfrecord_file``. To store several
    samples in one file, write all of them through a single writer (e.g. in a
    loop) instead of calling this function repeatedly.

    Args:
        label: integer class label, stored under the ``'label'`` key.
        shape: raw bytes describing the image shape, stored under ``'shape'``.
        binary_image: raw pixel bytes, stored under ``'image'``.
        tfrecord_file: path of the TFRecord file to write.
    """
    # Build the example first, so a bad label or payload fails before the
    # output file is opened.
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': _int64_feature(label),
        'shape': _bytes_feature(shape),
        'image': _bytes_feature(binary_image),
    }))
    # The context manager closes the writer even if serialization or the
    # write itself raises.
    with tf.python_io.TFRecordWriter(tfrecord_file) as writer:
        writer.write(example.SerializeToString())
49
50
def write_tfrecord(label, image_file, tfrecord_file):
    """Read ``image_file`` and store it, together with ``label``, in ``tfrecord_file``.

    This is a convenience wrapper: ``get_image_binary`` produces the shape and
    pixel bytes, and ``write_to_tfrecord`` serializes them.
    """
    shape_bytes, pixel_bytes = get_image_binary(image_file)
    write_to_tfrecord(label, shape_bytes, pixel_bytes, tfrecord_file)
53
54
def read_from_tfrecord(filenames):
    """Build graph ops that read and decode one example from TFRecord files.

    Uses the TF1 queue-based input pipeline (``tf.train.string_input_producer``
    plus ``tf.TFRecordReader``). The returned tensors are meant to be
    evaluated in a session after queue runners have been started, as
    ``read_tfrecord`` does.

    Args:
        filenames: list of TFRecord file paths to read from.

    Returns:
        A ``(label, shape, image)`` tuple of tensors:
        - ``label``: the int64 label.
        - ``shape``: the int32 vector decoded from the ``'shape'`` bytes.
        - ``image``: the uint8 pixel tensor, reshaped to ``shape``.
    """
    tfrecord_file_queue = tf.train.string_input_producer(filenames, name='queue')
    reader = tf.TFRecordReader()
    # reader.read yields (key, value); only the serialized record is needed.
    _, tfrecord_serialized = reader.read(tfrecord_file_queue)

    # 'label' is stored as an int64 scalar. 'shape' and 'image' are stored as
    # raw bytes (tf.string) and decoded below. A tf.Example can also hold
    # int64 or float lists directly.
    tfrecord_features = tf.parse_single_example(tfrecord_serialized,
                                                features={
                                                    'label': tf.FixedLenFeature([], tf.int64),
                                                    'shape': tf.FixedLenFeature([], tf.string),
                                                    'image': tf.FixedLenFeature([], tf.string),
                                                }, name='features')
    # image was saved as uint8, so we have to decode as uint8.
    image = tf.decode_raw(tfrecord_features['image'], tf.uint8)
    # shape was saved as int32 values (see get_image_binary), so decode as int32.
    shape = tf.decode_raw(tfrecord_features['shape'], tf.int32)
    # the image tensor is flattened out, so we have to reconstruct the shape
    image = tf.reshape(image, shape)
    label = tfrecord_features['label']
    return label, shape, image
74
75
def read_tfrecord(tfrecord_file):
    """Read one example back from ``tfrecord_file``, print it, and display the image.

    Prints the decoded label and shape, then shows the image with matplotlib.

    Args:
        tfrecord_file: path of a TFRecord file, e.g. one written by
            ``write_tfrecord``.
    """
    label, shape, image = read_from_tfrecord([tfrecord_file])

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            label, image, shape = sess.run([label, image, shape])
        finally:
            # Stop and join the queue-runner threads even if sess.run fails,
            # so they are not left running against a session being closed.
            coord.request_stop()
            coord.join(threads)
    print(label)
    print(shape)
    plt.imshow(image)
    plt.show()
88
89
def main():
    """Write the sample image to a TFRecord file, then read it back and show it."""
    # assume the image has the label Chihuahua, which corresponds to class number 1
    label = 1
    # os.path.join adds a separator only when needed, so this works whether
    # or not IMAGE_PATH ends with '/'. Plain '+' concatenation does not.
    image_file = os.path.join(IMAGE_PATH, 'friday.jpg')
    tfrecord_file = os.path.join(IMAGE_PATH, 'friday.tfrecord')
    write_tfrecord(label, image_file, tfrecord_file)
    read_tfrecord(tfrecord_file)
96
97
# Run the write-then-read demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()
99
100
101