"""
Title: Barlow Twins for Contrastive SSL
Author: [Abhiraam Eranti](https://github.com/dewball345)
Date created: 11/4/21
Last modified: 12/20/21
Description: A Keras implementation of Barlow Twins (contrastive SSL with redundancy reduction).
Accelerator: GPU
"""

"""
## Introduction
"""

"""
Self-supervised learning (SSL) is a relatively novel technique in which a model
learns from unlabeled data. It is often used when labeled data is scarce or
unreliable. A practical use for SSL is to create intermediate embeddings that are
learned from the data itself: similar images should have similar embeddings, and
dissimilar images should have dissimilar ones. These embeddings are then attached
to the rest of the model, which uses them as input features to learn the
downstream task and make predictions.

Ideally, these embeddings contain as much information about the data as
possible, so that the model can make better predictions. However, a common
problem is that the model creates redundant embeddings. For example, if two
images are similar, the model may create embeddings that are just a string of
1's, or some other value that contains repeating bits of information. This is no
better than a one-hot encoding or a single bit as the model's representation; it
defeats the purpose of the embeddings, because they do not capture as much about
the dataset as possible. Other approaches address this problem by carefully
configuring the model so that it avoids redundant representations.

Barlow Twins is a new approach to this problem: while other solutions mainly
tackle the first goal of invariance (similar images have similar embeddings),
the Barlow Twins method also explicitly targets the second goal of reducing
redundancy.

It is also much simpler than other methods, and its architecture is symmetric,
meaning that both twins share the same network and do the same thing. It is also
near state-of-the-art on ImageNet, even exceeding methods like SimCLR.

One disadvantage of Barlow Twins is that it depends heavily on augmentation and
suffers major drops in accuracy without it.

TL;DR: Barlow Twins creates representations that are:

* Invariant.
* Not redundant, carrying as much information about the dataset as possible.

It is also simpler than other methods.

This notebook can train a Barlow Twins model that reaches up to
64% validation accuracy on the CIFAR-10 dataset (under linear evaluation).
"""

"""
![image](https://i.imgur.com/G6LnEPT.png)
"""

"""
### High-Level Theory
"""

"""
The model takes two versions of the same image (with different augmentations) as
input. It then computes a representation (embedding) of each version, and the two
batches of representations are used to build a cross-correlation matrix.

Cross-correlation matrix:
```
(pred_1.T @ pred_2) / batch_size
```

Here `pred_1` and `pred_2` are the two batches of embeddings, each normalized
along the batch dimension (zero mean and unit standard deviation per feature).

The cross-correlation matrix measures the correlation between the output
neurons of the two representations produced for the two augmented versions of
the data. Ideally, the cross-correlation matrix should look like an identity
matrix when the two versions come from the same images.

When this happens, it means that the representations:

1. Are invariant. The diagonal shows the correlation between each neuron of one
representation and the corresponding neuron of the other. Because the two
versions come from the same image, the diagonal entries should show a strong
correlation (close to 1). If the images were different, there would be no such
diagonal.
2. Do not show signs of redundancy. If a neuron is correlated with a different,
non-corresponding neuron (an off-diagonal entry), the two neurons encode
overlapping information, which means the representation is redundant.

Here is a good way of understanding this in pseudocode (information from the
original paper):

```
c[i][i] = 1
c[i][j] = 0   (for i != j)

where:
c is the cross-correlation matrix
i is the index of one representation's neuron
j is the index of the second representation's neuron
```

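
As a rough illustration, the NumPy sketch below (not part of the original
implementation; the helper name `cross_correlation` and the random data are
assumptions for demonstration) normalizes two batches of embeddings along the
batch dimension and builds the cross-correlation matrix. When both inputs are
identical, every diagonal entry is exactly 1; adding independent noise to one view
pulls the diagonal below 1.

```python
import numpy as np


def cross_correlation(z_a, z_b):
    # Standardize each feature (column) over the batch: zero mean, unit std.
    # Like tf.math.reduce_std, np.std uses the population standard deviation.
    z_a_norm = (z_a - z_a.mean(axis=0)) / z_a.std(axis=0)
    z_b_norm = (z_b - z_b.mean(axis=0)) / z_b.std(axis=0)
    # (D, N) @ (N, D) -> (D, D), averaged over the batch size N.
    return z_a_norm.T @ z_b_norm / z_a.shape[0]


rng = np.random.default_rng(0)
z = rng.normal(size=(256, 8))  # hypothetical batch: 256 samples, 8 features
noisy = z + 0.5 * rng.normal(size=z.shape)  # a "second view" with extra noise

c_same = cross_correlation(z, z)
c_noisy = cross_correlation(z, noisy)
print(np.round(np.diag(c_same), 3))
print(np.round(np.diag(c_noisy), 3))
```

In the notebook, the same computation is done with TensorFlow ops by
`BarlowLoss.normalize()` and `BarlowLoss.cross_corr_matrix()`.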
"""

"""
Taken from the original paper: [Barlow Twins: Self-Supervised Learning via Redundancy
Reduction](https://arxiv.org/abs/2103.03230)
"""

"""
### References
"""

"""
Paper:
[Barlow Twins: Self-Supervised Learning via Redundancy
Reduction](https://arxiv.org/abs/2103.03230)

Original Implementation:
[facebookresearch/barlowtwins](https://github.com/facebookresearch/barlowtwins)
"""

"""
## Setup
"""

"""shell
pip install tensorflow-addons
"""

import os

# Slightly faster training: the first epoch is about 30 seconds faster and each
# later epoch is 1-2 seconds faster, which saves roughly 5 minutes in total.

# Gives the GPU its own private threads (two by default) for launching kernels,
# which allows more operations to be done faster.
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"

import tensorflow as tf  # framework
from tensorflow import keras  # for tf.keras
import tensorflow_addons as tfa  # LAMB optimizer and gaussian_filter2d function
import numpy as np  # np.random.random
import matplotlib.pyplot as plt  # graphs
import datetime  # tensorboard logs naming

# XLA optimization for faster performance (up to 10-15 minutes total time saved)
tf.config.optimizer.set_jit(True)

"""
## Load the CIFAR-10 dataset
"""

[
    (train_features, train_labels),
    (test_features, test_labels),
] = keras.datasets.cifar10.load_data()

# Scale pixel values from [0, 255] to [0, 1].
train_features = train_features / 255.0
test_features = test_features / 255.0

"""
## Necessary Hyperparameters
"""

# Batch size of dataset
BATCH_SIZE = 512
# Width and height of image
IMAGE_SIZE = 32

"""
## Augmentation Utilities
The Barlow Twins algorithm relies heavily on augmentation. As in the original
paper, most of the augmentations are applied probabilistically: each one runs only
with a given probability.

**Augmentations**

* *RandomToGrayscale*: randomly converts the image to grayscale 20% of the time
* *RandomColorJitter*: randomly applies color jitter 80% of the time
* *RandomFlip*: randomly flips the image horizontally 50% of the time
* *RandomResizedCrop*: randomly crops the image to a random size, then resizes it
back to the original size. This happens 100% of the time
* *RandomSolarize*: randomly applies solarization to the image 20% of the time
* *RandomBlur*: randomly blurs the image 20% of the time

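
The probabilistic pattern used by every layer below boils down to drawing a
uniform random number and applying the augmentation only if it falls below the
probability. The NumPy sketch below mimics that idea outside of TensorFlow; the
helper names (`apply_with_probability`, `to_grayscale`) are hypothetical, and the
grayscale weights are the standard BT.601 luma coefficients rather than a call
into `tf.image`.

```python
import numpy as np


def apply_with_probability(image, augment, prob, rng):
    # Same idea as Augmentation.random_execute: draw u ~ U[0, 1) and run the
    # augmentation only when u < prob.
    if rng.uniform() < prob:
        return augment(image)
    return image


def to_grayscale(image):
    # Weighted sum of the RGB channels (BT.601 luma weights), repeated back to
    # three channels so the output shape matches the input shape.
    gray = image @ np.array([0.299, 0.587, 0.114])
    return np.repeat(gray[..., None], 3, axis=-1)


rng = np.random.default_rng(42)
image = rng.uniform(size=(32, 32, 3))  # hypothetical RGB image in [0, 1]

n_trials = 10_000
n_changed = sum(
    not np.array_equal(apply_with_probability(image, to_grayscale, 0.2, rng), image)
    for _ in range(n_trials)
)
print(n_changed / n_trials)  # should be close to 0.2
```

Because each augmentation makes its own independent draw, the layers can be
chained and each one still fires at its own rate.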
"""


class Augmentation(keras.layers.Layer):
    """Base augmentation class.

    Base augmentation class. Contains the random_execute method.

    Methods:
        random_execute: method that returns true or false based
          on a probability. Used to determine whether an augmentation
          will be run.
    """

    def __init__(self):
        super().__init__()

    @tf.function
    def random_execute(self, prob: float) -> bool:
        """random_execute function.

        Arguments:
            prob: a float value from 0-1 that determines the
              probability.

        Returns:
            returns true or false based on the probability.
        """

        return tf.random.uniform([], minval=0, maxval=1) < prob


class RandomToGrayscale(Augmentation):
    """RandomToGrayscale class.

    RandomToGrayscale class. Randomly converts an image to
    grayscale based on the random_execute method. There
    is a 20% chance that an image will be grayscaled.

    Methods:
        call: method that grayscales an image 20% of
          the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a grayscaled version of the image 20% of the time
              and the original image 80% of the time.
        """

        if self.random_execute(0.2):
            x = tf.image.rgb_to_grayscale(x)
            # Repeat the single gray channel so the image keeps 3 channels.
            x = tf.tile(x, [1, 1, 3])
        return x


class RandomColorJitter(Augmentation):
    """RandomColorJitter class.

    RandomColorJitter class. Randomly adds color jitter to an image.
    Color jitter means randomly changing the brightness, contrast,
    saturation, and hue of an image. There is an 80% chance that an
    image will be randomly color-jittered.

    Methods:
        call: method that color-jitters an image 80% of
          the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Adds color jitter to image, including:
          Brightness change by a max-delta of 0.8
          Contrast change by a random factor in [0.4, 1.6]
          Saturation change by a random factor in [0.4, 1.6]
          Hue change by a max-delta of 0.2
        Originally, the same deltas as the original paper
        were used, but a performance boost of almost 2% was found
        when increasing them (roughly doubling).

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a color-jittered version of the image 80% of the time
              and the original image 20% of the time.
        """

        if self.random_execute(0.8):
            x = tf.image.random_brightness(x, 0.8)
            x = tf.image.random_contrast(x, 0.4, 1.6)
            x = tf.image.random_saturation(x, 0.4, 1.6)
            x = tf.image.random_hue(x, 0.2)
        return x


class RandomFlip(Augmentation):
    """RandomFlip class.

    RandomFlip class. Randomly flips image horizontally. There is a 50%
    chance that an image will be flipped.

    Methods:
        call: method that flips an image 50% of
          the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Randomly flips the image.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a flipped version of the image 50% of the time
              and the original image 50% of the time.
        """

        if self.random_execute(0.5):
            # random_execute already made the 50% decision, so flip
            # deterministically; random_flip_left_right here would flip
            # only half of the selected images (25% overall).
            x = tf.image.flip_left_right(x)
        return x


class RandomResizedCrop(Augmentation):
    """RandomResizedCrop class.

    RandomResizedCrop class. Randomly crops an image to a random size,
    then resizes the image back to the original size.

    Attributes:
        image_size: The dimension of the image

    Methods:
        call: method that does random resize crop to the image.
    """

    def __init__(self, image_size):
        super().__init__()
        self.image_size = image_size

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Does random resize crop by randomly cropping an image to a random
        size 75% - 100% the size of the image. Then resizes it.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a randomly cropped image.
        """

        rand_size = tf.random.uniform(
            shape=[],
            minval=int(0.75 * self.image_size),
            # maxval is exclusive, so add 1 to allow full-size (100%) crops.
            maxval=self.image_size + 1,
            dtype=tf.int32,
        )

        crop = tf.image.random_crop(x, (rand_size, rand_size, 3))
        crop_resize = tf.image.resize(crop, (self.image_size, self.image_size))
        return crop_resize


class RandomSolarize(Augmentation):
    """RandomSolarize class.

    RandomSolarize class. Randomly solarizes an image.
    Solarization inverts every pixel value above a threshold.

    Methods:
        call: method that does random solarization 20% of the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Randomly solarizes the image.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a solarized version of the image 20% of the time
              and the original image 80% of the time.
        """

        if self.random_execute(0.2):
            # Pixels are scaled to [0, 1]: keep the dark pixels (< 0.5) and
            # invert the bright ones (x -> 1 - x).
            x = tf.where(x < 0.5, x, 1.0 - x)
        return x


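"""
Solarization keeps dark pixels unchanged and inverts the bright ones. For 8-bit
images, PIL's `ImageOps.solarize` defaults to a threshold of 128; because this
notebook scales pixels to [0, 1], the NumPy sketch below assumes the midpoint 0.5
as the equivalent threshold. `solarize` is an illustrative helper, not part of the
notebook.

```python
import numpy as np


def solarize(x, threshold=0.5):
    # Keep pixels below the threshold; invert the rest (x -> 1 - x).
    return np.where(x < threshold, x, 1.0 - x)


pixels = np.array([0.0, 0.2, 0.5, 0.8, 1.0])  # hypothetical [0, 1] pixel values
print(solarize(pixels))
```

The TensorFlow counterpart of `np.where` is `tf.where(condition, x, y)`, which
selects element-wise in the same way.
"""
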
class RandomBlur(Augmentation):
    """RandomBlur class.

    RandomBlur class. Randomly blurs an image.

    Methods:
        call: method that does random blur 20% of the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Randomly blurs the image.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a blurred version of the image 20% of the time
              and the original image 80% of the time.
        """

        if self.random_execute(0.2):
            # Note: np.random.random() runs in Python, so when this layer is
            # traced into a graph (for example inside tf.data's map), sigma is
            # sampled once at trace time rather than once per image.
            s = np.random.random()
            return tfa.image.gaussian_filter2d(image=x, sigma=s)
        return x


class RandomAugmentor(keras.Model):
    """RandomAugmentor class.

    RandomAugmentor class. Chains all the augmentations into
    one pipeline.

    Attributes:
        image_size: An integer representing the width and height
          of the image. Designed to be used for square images.
        random_resized_crop: Instance variable representing the
          RandomResizedCrop layer.
        random_flip: Instance variable representing the
          RandomFlip layer.
        random_color_jitter: Instance variable representing the
          RandomColorJitter layer.
        random_blur: Instance variable representing the
          RandomBlur layer
        random_to_grayscale: Instance variable representing the
          RandomToGrayscale layer
        random_solarize: Instance variable representing the
          RandomSolarize layer

    Methods:
        call: chains layers in pipeline together
    """

    def __init__(self, image_size: int):
        super().__init__()

        self.image_size = image_size
        self.random_resized_crop = RandomResizedCrop(image_size)
        self.random_flip = RandomFlip()
        self.random_color_jitter = RandomColorJitter()
        self.random_blur = RandomBlur()
        self.random_to_grayscale = RandomToGrayscale()
        self.random_solarize = RandomSolarize()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        x = self.random_resized_crop(x)
        x = self.random_flip(x)
        x = self.random_color_jitter(x)
        x = self.random_blur(x)
        x = self.random_to_grayscale(x)
        x = self.random_solarize(x)

        # Color jitter can push values outside [0, 1]; clip them back.
        x = tf.clip_by_value(x, 0, 1)
        return x


bt_augmentor = RandomAugmentor(IMAGE_SIZE)

"""
## Data Loading

A class that creates the Barlow Twins dataset.

The dataset consists of two copies of each image, with each copy receiving different
augmentations. Both copies are shuffled with the same seed, so the two streams stay
aligned image by image.

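
Why does a shared seed keep the pairs aligned? Two shuffles driven by generators
seeded identically produce the same permutation. The NumPy sketch below
demonstrates this with `np.random.default_rng`; it is only an analogy for the
seeded `Dataset.shuffle` calls in `BTDatasetCreator`, not a reproduction of
tf.data's shuffle algorithm, and the "augmentation" is hypothetical additive noise.

```python
import numpy as np

images = np.arange(10, dtype=float)  # stand-in dataset: image i has value i
seed = 1024  # same default seed as BTDatasetCreator


def augmented_stream(data, seed, noise_seed):
    # Shuffle with the shared seed, then apply view-specific random "augmentation".
    order = np.random.default_rng(seed).permutation(len(data))
    noise = np.random.default_rng(noise_seed).normal(scale=0.1, size=len(data))
    return order, data[order] + noise


order_a, view_a = augmented_stream(images, seed, noise_seed=1)
order_b, view_b = augmented_stream(images, seed, noise_seed=2)
print(order_a)
print(np.round(view_a - view_b, 2))  # small differences: same images, different noise
```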
"""


class BTDatasetCreator:
    """Barlow twins dataset creator class.

    BTDatasetCreator class. Responsible for creating the
    Barlow Twins dataset.

    Attributes:
        options: tf.data.Options needed to configure a setting
          that may improve performance.
        seed: random seed for shuffling. Used to synchronize two
          augmented versions.
        augmentor: augmentor used for augmentation.

    Methods:
        __call__: creates barlow dataset.
        augmented_version: creates 1 half of the dataset.
    """

    def __init__(self, augmentor: RandomAugmentor, seed: int = 1024):
        self.options = tf.data.Options()
        self.options.threading.max_intra_op_parallelism = 1
        self.seed = seed
        self.augmentor = augmentor

    def augmented_version(self, ds: np.ndarray) -> tf.data.Dataset:
        return (
            tf.data.Dataset.from_tensor_slices(ds)
            .shuffle(1000, seed=self.seed)
            .map(self.augmentor, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(BATCH_SIZE, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE)
            .with_options(self.options)
        )

    def __call__(self, ds: np.ndarray) -> tf.data.Dataset:
        a1 = self.augmented_version(ds)
        a2 = self.augmented_version(ds)

        return tf.data.Dataset.zip((a1, a2)).with_options(self.options)


augment_versions = BTDatasetCreator(bt_augmentor)(train_features)

"""
View examples of dataset.
"""

sample_augment_versions = iter(augment_versions)


def plot_values(batch: tuple):
    fig, axs = plt.subplots(3, 3)
    fig1, axs1 = plt.subplots(3, 3)

    fig.suptitle("Augmentation 1")
    fig1.suptitle("Augmentation 2")

    a1, a2 = batch

    # plot the first 9 images of each augmented batch
    for i in range(3):
        for j in range(3):
            axs[i][j].imshow(a1[3 * i + j])
            axs[i][j].axis("off")
            axs1[i][j].imshow(a2[3 * i + j])
            axs1[i][j].axis("off")

    plt.show()


plot_values(next(sample_augment_versions))

"""
## Pseudocode of loss and model
The following sections follow the original authors' pseudocode, which contains both the
model and the loss function (see the diagram below). A reference of the variables used is
also included.
"""

"""
![pseudocode](https://i.imgur.com/Tlrootj.png)
"""

"""
Reference:

```
y_a: first augmented version of original image.
y_b: second augmented version of original image.
z_a: model representation (embeddings) of y_a.
z_b: model representation (embeddings) of y_b.
z_a_norm: normalized z_a.
z_b_norm: normalized z_b.
c: cross correlation matrix.
c_diff: diagonal portion of loss (invariance term).
off_diag: off-diagonal portion of loss (redundancy reduction term).
```
"""

"""
## BarlowLoss: barlow twins model's loss function

Barlow Twins uses the cross-correlation matrix for its loss. There are two parts to the
loss function:

* ***The invariance term*** (diagonal). This part pushes the diagonal entries of the
matrix toward 1. When this is the case, the matrix shows that the two views are
perfectly correlated (the same).
    * The loss function subtracts 1 from each diagonal entry and squares the result.
* ***The redundancy reduction term*** (off-diagonal). Here, the Barlow Twins loss
function pushes these values toward zero. As mentioned before, the representation is
redundant if its neurons are correlated with values that are not on the diagonal.
    * The off-diagonal entries are squared and multiplied by a weight, lambda
    (`lambda_amt = 5e-3` in this example).

Finally, the two parts are summed.

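
To make the two terms concrete, here is a small NumPy sketch of the same
computation (normalize, cross-correlate, then add the invariance term and the
weighted redundancy term). It mirrors the structure of `BarlowLoss` below but is
only an illustration: the function name `barlow_loss`, the toy embeddings, and the
use of NumPy instead of TensorFlow are assumptions. It compares a pair of views
with independent features against a pair whose features are all copies of one
feature, which the redundancy term should penalize.

```python
import numpy as np


def barlow_loss(z_a, z_b, lambda_amt=5e-3):
    n = z_a.shape[0]
    # Normalize each feature along the batch dimension (zero mean, unit std).
    z_a_norm = (z_a - z_a.mean(axis=0)) / z_a.std(axis=0)
    z_b_norm = (z_b - z_b.mean(axis=0)) / z_b.std(axis=0)
    c = z_a_norm.T @ z_b_norm / n  # (D, D) cross-correlation matrix

    # Invariance term: (c_ii - 1)^2 for every diagonal entry.
    c_diff = (np.diag(c) - 1.0) ** 2
    # Redundancy reduction term: c_ij^2 for every off-diagonal entry, weighted by lambda.
    off_diag = c - np.diag(np.diag(c))
    return c_diff.sum() + lambda_amt * (off_diag**2).sum()


rng = np.random.default_rng(0)
base = rng.normal(size=(512, 16))  # hypothetical embeddings, 16 independent features
view_a = base + 0.1 * rng.normal(size=base.shape)
view_b = base + 0.1 * rng.normal(size=base.shape)

# Redundant embeddings: all 16 features are noisy copies of the same feature.
redundant = np.repeat(base[:, :1], 16, axis=1)
red_a = redundant + 0.1 * rng.normal(size=redundant.shape)
red_b = redundant + 0.1 * rng.normal(size=redundant.shape)

loss_independent = barlow_loss(view_a, view_b)
loss_redundant = barlow_loss(red_a, red_b)
print(loss_independent, loss_redundant)
```

With D features there are D diagonal entries but D * D - D off-diagonal ones, which
is why the off-diagonal sum is scaled down by lambda.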
"""


class BarlowLoss(keras.losses.Loss):
    """BarlowLoss class.

    BarlowLoss class. Creates a loss function based on the cross-correlation
    matrix.

    Attributes:
        batch_size: the batch size of the dataset
        lambda_amt: the value for lambda (used in cross_corr_matrix_loss)

    Methods:
        __init__: gets instance variables
        call: gets the loss based on the cross-correlation matrix
        get_off_diag: Used in calculating off-diagonal section
          of loss function; makes diagonals zeros.
        cross_corr_matrix_loss: creates loss based on cross correlation
          matrix.
    """

    def __init__(self, batch_size: int):
        """__init__ method.

        Gets the instance variables

        Arguments:
            batch_size: An integer value representing the batch size of the
              dataset. Used for cross correlation matrix calculation.
        """

        super().__init__()
        self.lambda_amt = 5e-3
        self.batch_size = batch_size

    def get_off_diag(self, c: tf.Tensor) -> tf.Tensor:
        """get_off_diag method.

        Makes the diagonals of the cross correlation matrix zeros.
        This is used in the off-diagonal portion of the loss function,
        where we take the squares of the off-diagonal values and sum them.

        Arguments:
            c: A tf.tensor that represents the cross correlation
              matrix

        Returns:
            Returns a tf.tensor which represents the cross correlation
              matrix with its diagonals as zeros.
        """

        zero_diag = tf.zeros(c.shape[-1])
        return tf.linalg.set_diag(c, zero_diag)

    def cross_corr_matrix_loss(self, c: tf.Tensor) -> tf.Tensor:
        """cross_corr_matrix_loss method.

        Gets the loss based on the cross correlation matrix.
        We want the diagonals to be 1's and everything else to be
        zeros to show that the two augmented images are similar.

        Loss function procedure:
        take the diagonal of the cross-correlation matrix, subtract 1,
        and square that value so there are no negatives.

        Take the off-diagonal of the cc-matrix (see get_off_diag()),
        square those values to get rid of negatives, and multiply them
        by lambda so that they are weighted comparably to the diagonal
        for the optimizer (there are more values off-diag than on-diag).

        Sum each of the two parts, then add the two sums together.

        Arguments:
            c: A tf.tensor that represents the cross correlation
              matrix

        Returns:
            Returns a (rank-0) tf.Tensor that represents the loss.
        """

        # subtracts 1 from the diagonals and squares them (first part)
        c_diff = tf.pow(tf.linalg.diag_part(c) - 1, 2)

        # takes off diagonal, squares it, multiplies with lambda (second part)
        off_diag = tf.pow(self.get_off_diag(c), 2) * self.lambda_amt

        # sum first and second parts together
        loss = tf.reduce_sum(c_diff) + tf.reduce_sum(off_diag)

        return loss

    def normalize(self, output: tf.Tensor) -> tf.Tensor:
        """normalize method.

        Normalizes the model prediction along the batch dimension
        (zero mean and unit standard deviation per feature).

        Arguments:
            output: the model prediction.

        Returns:
            Returns a normalized version of the model prediction.
        """

        return (output - tf.reduce_mean(output, axis=0)) / tf.math.reduce_std(
            output, axis=0
        )

    def cross_corr_matrix(self, z_a_norm: tf.Tensor, z_b_norm: tf.Tensor) -> tf.Tensor:
        """cross_corr_matrix method.

        Creates a cross correlation matrix from the predictions.
        It transposes the first prediction and multiplies this with
        the second, creating a matrix with shape (n_dense_units, n_dense_units).
        See build_twin() for more info. Then it divides this by the
        batch size.

        Arguments:
            z_a_norm: A normalized version of the first prediction.
            z_b_norm: A normalized version of the second prediction.

        Returns:
            Returns a cross correlation matrix.
        """
        return (tf.transpose(z_a_norm) @ z_b_norm) / self.batch_size

    def call(self, z_a: tf.Tensor, z_b: tf.Tensor) -> tf.Tensor:
        """call method.

        Makes the cross-correlation loss. Uses cross_corr_matrix()
        to make the cross corr matrix, then finds the loss and
        returns it (see cross_corr_matrix_loss()).

        Arguments:
            z_a: The prediction of the first set of augmented data.
            z_b: The prediction of the second set of augmented data.

        Returns:
            Returns a (rank-0) tf.Tensor that represents the loss.
        """

        z_a_norm, z_b_norm = self.normalize(z_a), self.normalize(z_b)
        c = self.cross_corr_matrix(z_a_norm, z_b_norm)
        loss = self.cross_corr_matrix_loss(c)
        return loss


"""
## Barlow Twins' Model Architecture
The model has two parts:

* The encoder network, which is a ResNet-34.
* The projector network, which creates the model embeddings.
    * This is an MLP of three dense layers; the first two are each followed by
    batch normalization and a ReLU activation.
"""

"""
ResNet encoder network implementation:

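
For a 32x32 CIFAR-10 image, it helps to track how the spatial size shrinks through
the `ResNet34` class below. The Python sketch below does that arithmetic, assuming
the usual rule that a `padding="same"` convolution or pooling layer with stride s
maps a size n to ceil(n / s); `same_padding_out` is a hypothetical helper, not part
of the notebook.

```python
import math


def same_padding_out(n, stride):
    # Output size of a padding="same" conv/pool layer.
    return math.ceil(n / stride)


size = 32
size += 2 * 3  # ZeroPadding2D((3, 3)) adds 3 pixels on each side -> 38
size = same_padding_out(size, 2)  # 7x7 conv, stride 2 -> 19
size = same_padding_out(size, 2)  # 3x3 max pool, stride 2 -> 10

filters = 64  # stage 1: identity blocks only, no downsampling
for stage in range(2, 5):  # stages 2-4 each start with a stride-2 conv block
    filters *= 2
    size = same_padding_out(size, 2)
    print(f"stage {stage}: {size}x{size}x{filters}")

size = same_padding_out(size, 2)  # AveragePooling2D((2, 2), padding="same")
print("flattened features:", size * size * filters)
```

So the encoder hands a 512-dimensional feature vector to the projector.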
"""


class ResNet34:
    """Resnet34 class.

    Responsible for the Resnet 34 architecture.
    Modified from
    https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2.
    View their website for more information.
    """

    def identity_block(self, x, filters):
        # keep the input for the skip connection
        x_skip = x
        # Layer 1
        x = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        x = tf.keras.layers.Activation("relu")(x)
        # Layer 2
        x = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        # Add Residue
        x = tf.keras.layers.Add()([x, x_skip])
        x = tf.keras.layers.Activation("relu")(x)
        return x

    def convolutional_block(self, x, filters):
        # keep the input for the skip connection
        x_skip = x
        # Layer 1
        x = tf.keras.layers.Conv2D(filters, (3, 3), padding="same", strides=(2, 2))(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        x = tf.keras.layers.Activation("relu")(x)
        # Layer 2
        x = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        # Process the residue with a strided 1x1 conv so the shapes match
        x_skip = tf.keras.layers.Conv2D(filters, (1, 1), strides=(2, 2))(x_skip)
        # Add Residue
        x = tf.keras.layers.Add()([x, x_skip])
        x = tf.keras.layers.Activation("relu")(x)
        return x

    def __call__(self, shape=(32, 32, 3)):
        # Step 1 (Setup Input Layer)
        x_input = tf.keras.layers.Input(shape)
        x = tf.keras.layers.ZeroPadding2D((3, 3))(x_input)
        # Step 2 (Initial Conv layer along with maxPool)
        x = tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding="same")(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)
        x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same")(x)
        # Define size of sub-blocks and initial filter size
        block_layers = [3, 4, 6, 3]
        filter_size = 64
        # Step 3 Add the Resnet Blocks
        for i in range(4):
            if i == 0:
                # For sub-block 1 Residual/Convolutional block not needed
                for j in range(block_layers[i]):
                    x = self.identity_block(x, filter_size)
            else:
                # One Residual/Convolutional Block followed by Identity blocks
                # The filter size doubles at every new sub-block
                filter_size = filter_size * 2
                x = self.convolutional_block(x, filter_size)
                for j in range(block_layers[i] - 1):
                    x = self.identity_block(x, filter_size)
        # Step 4: pool and flatten into a feature vector
        x = tf.keras.layers.AveragePooling2D((2, 2), padding="same")(x)
        x = tf.keras.layers.Flatten()(x)
        model = tf.keras.models.Model(inputs=x_input, outputs=x, name="ResNet34")
        return model


"""
Projector network:

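
The projector in `build_twin()` stacks Dense, BatchNormalization and ReLU twice,
then ends with a plain Dense layer. The NumPy sketch below runs the same sequence
forward on a toy batch, using training-mode batch normalization (batch statistics,
scale 1, offset 0) and small hypothetical sizes (64 inputs, 32 units) instead of
the notebook's 512-dimensional encoder output and 5000 units; the helper names are
illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)


def dense(x, units):
    # Randomly initialized fully connected layer (weights and zero bias).
    w = rng.normal(scale=1.0 / np.sqrt(x.shape[1]), size=(x.shape[1], units))
    b = np.zeros(units)
    return x @ w + b


def batch_norm(x, eps=1e-3):
    # Training-mode batch norm with scale (gamma) 1 and offset (beta) 0.
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


def projector(features, units=32, n_layers=2):
    x = features
    for _ in range(n_layers):
        x = np.maximum(batch_norm(dense(x, units)), 0.0)  # Dense -> BN -> ReLU
    return dense(x, units)  # final Dense, no BN or ReLU


features = rng.normal(size=(128, 64))  # hypothetical encoder outputs
embeddings = projector(features)
print(embeddings.shape)  # (128, 32)
```

Leaving the last layer without batch normalization or ReLU lets the embeddings
take negative values, so each feature can be freely normalized by the loss.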
"""


def build_twin() -> keras.Model:
    """build_twin method.

    Builds a barlow twins model consisting of an encoder (resnet-34)
    and a projector, which generates embeddings for the images

    Returns:
        returns a barlow twins model
    """

    # number of dense neurons in the projector
    n_dense_neurons = 5000

    # encoder network
    resnet = ResNet34()()
    last_layer = resnet.layers[-1].output

    # intermediate layers of the projector network
    n_layers = 2
    for i in range(n_layers):
        dense = tf.keras.layers.Dense(n_dense_neurons, name=f"projector_dense_{i}")
        if i == 0:
            x = dense(last_layer)
        else:
            x = dense(x)
        x = tf.keras.layers.BatchNormalization(name=f"projector_bn_{i}")(x)
        x = tf.keras.layers.ReLU(name=f"projector_relu_{i}")(x)

    # final projector layer: Dense only, without batch norm or ReLU
    x = tf.keras.layers.Dense(n_dense_neurons, name=f"projector_dense_{n_layers}")(x)

    model = keras.Model(resnet.input, x)
    return model


"""
## Training Loop Model

See the pseudocode for reference.

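
Conceptually, one training step passes both augmented views through the same
network (shared weights), computes the Barlow loss on the two outputs, and moves
the weights against the gradient. The NumPy sketch below walks through that loop
on a toy linear "network" and estimates the gradient with central finite
differences instead of `tf.GradientTape`; the linear encoder, step size, and data
are all assumptions chosen only to keep the example tiny.

```python
import numpy as np


def barlow_loss(z_a, z_b, lambda_amt=5e-3):
    # Mirrors the structure of BarlowLoss: normalize, cross-correlate, two terms.
    z_a = (z_a - z_a.mean(axis=0)) / z_a.std(axis=0)
    z_b = (z_b - z_b.mean(axis=0)) / z_b.std(axis=0)
    c = z_a.T @ z_b / z_a.shape[0]
    off_diag = c - np.diag(np.diag(c))
    return ((np.diag(c) - 1.0) ** 2).sum() + lambda_amt * (off_diag**2).sum()


def step_loss(w, y_a, y_b):
    # Both views pass through the SAME weights w: the twins share parameters.
    return barlow_loss(y_a @ w, y_b @ w)


def numerical_grad(w, y_a, y_b, eps=1e-5):
    # Central finite differences, standing in for tf.GradientTape.
    grad = np.zeros_like(w)
    for idx in np.ndindex(*w.shape):
        w_plus, w_minus = w.copy(), w.copy()
        w_plus[idx] += eps
        w_minus[idx] -= eps
        diff = step_loss(w_plus, y_a, y_b) - step_loss(w_minus, y_a, y_b)
        grad[idx] = diff / (2 * eps)
    return grad


rng = np.random.default_rng(0)
x = rng.normal(size=(256, 6))  # hypothetical inputs
y_a = x + 0.3 * rng.normal(size=x.shape)  # first "augmented" view
y_b = x + 0.3 * rng.normal(size=x.shape)  # second "augmented" view
w = rng.normal(size=(6, 4))  # toy linear encoder: 6 inputs -> 4 features

losses = [step_loss(w, y_a, y_b)]
for _ in range(20):
    w = w - 1.0 * numerical_grad(w, y_a, y_b)  # one gradient-descent step
    losses.append(step_loss(w, y_a, y_b))
print(f"loss before: {losses[0]:.5f}, after: {losses[-1]:.5f}")
```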
"""


class BarlowModel(keras.Model):
    """BarlowModel class.

    BarlowModel class. Responsible for making predictions and handling
    gradient descent with the optimizer.

    Attributes:
        model: the barlow model architecture.
        loss_tracker: the loss metric.

    Methods:
        train_step: one train step; do model predictions, loss, and
          optimizer step.
        metrics: Returns metrics.
    """

    def __init__(self):
        super().__init__()
        self.model = build_twin()
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        return [self.loss_tracker]

    def train_step(self, batch: tf.Tensor) -> dict:
        """train_step method.

        Do one train step. Make model predictions, find loss, pass loss to
        optimizer, and make optimizer apply gradients.

        Arguments:
            batch: one batch of data to be given to the loss function.

        Returns:
            Returns a dictionary with the loss metric.
        """

        # get the two augmentations from the batch
        y_a, y_b = batch

        with tf.GradientTape() as tape:
            # get two versions of predictions from the same (shared) network
            z_a, z_b = self.model(y_a, training=True), self.model(y_b, training=True)
            loss = self.loss(z_a, z_b)

        grads_model = tape.gradient(loss, self.model.trainable_variables)

        self.optimizer.apply_gradients(zip(grads_model, self.model.trainable_variables))
        self.loss_tracker.update_state(loss)

        return {"loss": self.loss_tracker.result()}


"""
## Model Training

* This example uses the LAMB optimizer instead of Adam or SGD.
* LAMB is similar to the LARS optimizer used in the paper, and it lets the model
converge much faster than Adam or SGD.
* Expected training time: 1 hour 30 min. Go and eat a snack or take a nap or something.

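
What makes LAMB (and LARS) suited to large batches is a layer-wise trust ratio:
each layer's Adam-style update is rescaled so that its length is proportional to
the norm of that layer's weights. The NumPy sketch below shows a single step for
one weight matrix following the update rule as it is commonly described; it is
not the `tfa.optimizers.LAMB` source (details such as epsilon placement and
weight-decay handling may differ), and all hyperparameter values are illustrative.

```python
import numpy as np


def lamb_step(w, g, m, v, t, lr=1e-3, beta_1=0.9, beta_2=0.999, eps=1e-6, weight_decay=0.0):
    # Adam-style moment estimates with bias correction (t starts at 1).
    m = beta_1 * m + (1 - beta_1) * g
    v = beta_2 * v + (1 - beta_2) * g**2
    m_hat = m / (1 - beta_1**t)
    v_hat = v / (1 - beta_2**t)
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * w

    # Layer-wise trust ratio: rescale the step by ||w|| / ||update||.
    w_norm = np.linalg.norm(w)
    u_norm = np.linalg.norm(update)
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    return w - lr * trust_ratio * update, m, v


rng = np.random.default_rng(0)
w = rng.normal(size=(64, 32))  # hypothetical weights of one layer
g = rng.normal(size=w.shape)  # hypothetical gradient for that layer
zeros = np.zeros_like(w)

w_small, _, _ = lamb_step(w, g, zeros, zeros, t=1)
w_large, _, _ = lamb_step(w, 100.0 * g, zeros, zeros, t=1)

# Both steps have length lr * ||w||, however large the gradient is.
print(np.linalg.norm(w_small - w), np.linalg.norm(w_large - w), 1e-3 * np.linalg.norm(w))
```

Tying the step length to each layer's weight norm rather than to the raw gradient
scale is the property usually credited for LAMB and LARS behaving well with large
batches such as the 512-image batches used here.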
"""

# sets up model, optimizer, loss

bm = BarlowModel()
# chose the LAMB optimizer due to high batch sizes. Converged MUCH faster
# than Adam or SGD
optimizer = tfa.optimizers.LAMB()
loss = BarlowLoss(BATCH_SIZE)

bm.compile(optimizer=optimizer, loss=loss)

# Expected training time: 1 hour 30 min

history = bm.fit(augment_versions, epochs=160)
plt.plot(history.history["loss"])
plt.show()

"""
## Evaluation

**Linear evaluation:** to evaluate the model's performance, we add
a linear dense layer at the end and freeze the main model's weights, letting only the
dense layer be tuned. If the model actually learned something, then the accuracy should
be significantly higher than random chance.

**Accuracy on CIFAR-10**: 64% for this notebook. This is much better than the 10% we get
from random guessing.

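
The idea of a linear probe can be sketched without TensorFlow: keep the features
fixed and fit only a softmax classifier on top of them. In the NumPy sketch below,
synthetic, well-separated "frozen features" stand in for the encoder's embeddings
(an assumption; real Barlow Twins features are not this clean), and the classifier
is trained with plain full-batch gradient descent rather than LAMB.

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes, n_per_class, dim = 3, 100, 8

# Hypothetical frozen features: one Gaussian cluster per class.
centers = rng.normal(scale=3.0, size=(n_classes, dim))
features = np.concatenate([c + rng.normal(size=(n_per_class, dim)) for c in centers])
labels = np.repeat(np.arange(n_classes), n_per_class)

# Only these parameters are trained; the features never change.
w = np.zeros((dim, n_classes))
b = np.zeros(n_classes)
one_hot = np.eye(n_classes)[labels]

for _ in range(200):
    logits = features @ w + b
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    grad_logits = (probs - one_hot) / len(labels)  # softmax cross-entropy gradient
    w -= 0.1 * features.T @ grad_logits
    b -= 0.1 * grad_logits.sum(axis=0)

accuracy = np.mean((features @ w + b).argmax(axis=1) == labels)
print(f"linear-probe accuracy: {accuracy:.2f} (chance: {1 / n_classes:.2f})")
```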
"""

# Approx: 64% accuracy with this barlow twins model.

xy_ds = (
    tf.data.Dataset.from_tensor_slices((train_features, train_labels))
    .shuffle(1000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# Note: drop_remainder=True leaves out the last 10000 % 512 = 272 test images.
test_ds = (
    tf.data.Dataset.from_tensor_slices((test_features, test_labels))
    .shuffle(1000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

model = keras.models.Sequential(
    [
        bm.model,
        keras.layers.Dense(
            10, activation="softmax", kernel_regularizer=keras.regularizers.l2(0.02)
        ),
    ]
)

# Freeze the pretrained Barlow Twins network so only the linear layer is trained.
model.layers[0].trainable = False

linear_optimizer = tfa.optimizers.LAMB()
model.compile(
    optimizer=linear_optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

model.fit(xy_ds, epochs=35, validation_data=test_ds)

"""
## Conclusion

* Barlow Twins is a simple and concise method for contrastive and self-supervised
learning.
* With this ResNet-34 model architecture, we were able to reach 62-64% validation
accuracy.

## Use-Cases of Barlow-Twins (and contrastive learning in general)

* Semi-supervised learning: this model reached 62-64% accuracy (versus 10% for random
guessing) with only a linear classifier trained on top of representations that were
learned without any labels. It can be used when you have little labeled data but a lot
of unlabeled data.
    * You do Barlow Twins training on the unlabeled data, and then you do secondary
    training with the labeled data.

## Helpful links

* [Paper](https://arxiv.org/abs/2103.03230)
* [Original PyTorch Implementation](https://github.com/facebookresearch/barlowtwins)
* [Sayak Paul's Implementation](https://colab.research.google.com/github/sayakpaul/Barlow-Twins-TF/blob/main/Barlow_Twins.ipynb#scrollTo=GlWepkM8_prl).
    * Thanks to Sayak Paul for his implementation. It helped me with debugging and
    comparisons of accuracy and loss.
* [resnet34 implementation](https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2)
    * Thanks to Yashowardhan Shinde for writing the article.
"""