</div>
---
We create three datasets:

1. A dataset with a smaller resolution (128x128).
2. Two datasets with a larger resolution (224x224).

We will apply different augmentation transforms to the two larger-resolution datasets.
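The idea of FixRes is to first train a model on a smaller-resolution dataset and then fine-tune
it on a larger-resolution dataset. The fine-tuning step compensates for the discrepancy between
the apparent object sizes seen during training (after random cropping) and at test time. This
simple yet effective recipe leads to non-trivial performance improvements. Please refer to the
[original paper](https://arxiv.org/abs/1906.06423) for detailed results.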
```python
# Input-pipeline hyperparameters.
batch_size = 32
auto = tf.data.AUTOTUNE

# Square resolutions: `smaller_size` is the default used by `make_dataset`
# for the initial stage, `bigger_size` is the fine-tuning resolution.
smaller_size = 128
bigger_size = 224

# Before center-cropping to `bigger_size`, images are upscaled by the same
# bigger/smaller ratio (224 * (224 / 128) = 392 pixels per side).
size_for_resizing = int(bigger_size * (bigger_size / smaller_size))
central_crop_layer = layers.CenterCrop(bigger_size, bigger_size)
def preprocess_initial(train, image_size):
    """Build the preprocessing function for the initial, smaller-resolution stage.

    Training path: random crop (via ``tf.image.sample_distorted_bounding_box``
    covering 5%-100% of the image area) -> resize to
    ``image_size x image_size`` -> random horizontal flip.
    Validation path: resize only.
    No color jittering is applied.

    Args:
        train: Default value for the ``train`` flag of the returned function.
            Callers may still override it by passing ``train`` explicitly.
        image_size: Side length (pixels) of the square output image.

    Returns:
        A callable ``_pp(image, label, train=train)`` returning
        ``(image, label)``.
    """

    # `train=train` makes the outer argument meaningful; previously it was
    # shadowed by a required inner parameter and silently ignored.
    def _pp(image, label, train=train):
        if train:
            channels = image.shape[-1]
            # No reference boxes are given; with
            # `use_image_if_no_bounding_boxes=True` the whole image is used
            # as the box from which a random crop window is sampled.
            begin, size, _ = tf.image.sample_distorted_bounding_box(
                tf.shape(image),
                tf.zeros([0, 0, 4], tf.float32),
                area_range=(0.05, 1.0),
                min_object_covered=0,
                use_image_if_no_bounding_boxes=True,
            )
            image = tf.slice(image, begin, size)
            # `size` is a tensor, so the static shape after slicing is
            # unknown; restore rank and channel count for downstream ops.
            image.set_shape([None, None, channels])
            image = tf.image.resize(image, [image_size, image_size])
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, [image_size, image_size])
        return image, label

    return _pp
def preprocess_finetune(image, label, train):
    """Preprocess one example for the higher-resolution fine-tuning stage.

    Pipeline: resize to ``size_for_resizing`` (upscaling by the
    ``bigger_size / smaller_size`` ratio) -> random horizontal flip (training
    only) -> center crop to ``bigger_size x bigger_size``.
    Validation follows the same path without the flip.
    No color jittering is applied.
    """
    target = [size_for_resizing, size_for_resizing]
    resized = tf.image.resize(image, target)

    augmented = resized
    if train:
        augmented = tf.image.random_flip_left_right(resized)

    # The CenterCrop layer operates on batches: add a leading batch axis,
    # crop, then take the single element back out.
    batched = augmented[None, ...]
    cropped = central_crop_layer(batched)[0]
    return cropped, label
def make_dataset(
    dataset: tf.data.Dataset,
    train: bool,
    image_size: int = smaller_size,
    fixres: bool = True,
    num_parallel_calls=auto,
):
    """Build a shuffled (training only), preprocessed, batched, prefetched pipeline.

    Args:
        dataset: Dataset of unbatched ``(image, label)`` pairs. The
            preprocessing functions operate on single images.
        train: Whether to shuffle and apply training-time augmentation.
        image_size: Target resolution; must be ``smaller_size`` or
            ``bigger_size``.
        fixres: For ``image_size == bigger_size``, use the FixRes fine-tuning
            preprocessing (``preprocess_finetune``) when True. When False, use
            the initial random-crop preprocessing at the bigger resolution.
        num_parallel_calls: Parallelism for ``map`` and the prefetch buffer
            size.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``batch_size`` examples.

    Raises:
        ValueError: If ``image_size`` is not a supported resolution.
    """
    if image_size not in (smaller_size, bigger_size):
        raise ValueError(f"{image_size} resolution is not supported.")

    # Initial (random-crop) preprocessing applies to the small resolution, and
    # to the big resolution when FixRes fine-tuning is disabled.
    if image_size == smaller_size or not fixres:
        preprocess_func = preprocess_initial(train, image_size)
    else:
        preprocess_func = preprocess_finetune

    # Shuffle individual examples *before* batching. Shuffling after `batch`
    # only permutes whole batches, so (given a fixed input order) each batch
    # would contain the same examples every epoch.
    if train:
        dataset = dataset.shuffle(batch_size * 10)

    dataset = dataset.map(
        lambda x, y: preprocess_func(x, y, train),
        num_parallel_calls=num_parallel_calls,
    )
    return dataset.batch(batch_size).prefetch(num_parallel_calls)