# -*- coding: utf-8 -*-
"""
Knowledge Distillation Tutorial
===============================
**Author**: `Alexandros Chariton <https://github.com/AlexandrosChrtn>`_
"""

######################################################################
# Knowledge distillation is a technique that enables knowledge transfer from large, computationally expensive
# models to smaller ones without losing validity. This allows for deployment on less powerful
# hardware, making evaluation faster and more efficient.
#
# In this tutorial, we will run a number of experiments focused on improving the accuracy of a
# lightweight neural network, using a more powerful network as a teacher.
# The computational cost and the speed of the lightweight network will remain unaffected;
# our intervention only focuses on its weights, not on its forward pass.
# Applications of this technology can be found in devices such as drones or mobile phones.
# In this tutorial, we do not use any external packages, as everything we need is available in ``torch`` and
# ``torchvision``.
#
# In this tutorial, you will learn:
#
# - How to modify model classes to extract hidden representations and use them for further calculations
# - How to modify regular train loops in PyTorch to include additional losses on top of, for example, cross-entropy for classification
# - How to improve the performance of lightweight models by using more complex models as teachers
#
# Prerequisites
# ~~~~~~~~~~~~~
#
# * 1 GPU, 4GB of memory
# * PyTorch v2.0 or later
# * CIFAR-10 dataset (downloaded by the script and saved in a directory called ``./data``)

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Check if the current `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
# is available, and if not, use the CPU
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

######################################################################
# Loading CIFAR-10
# ----------------
# CIFAR-10 is a popular image dataset with ten classes. Our objective is to predict one of the following classes for each input image.
#
# .. figure:: /../_static/img/cifar10.png
#    :align: center
#
#    Example of CIFAR-10 images
#
# The input images are RGB, so they have 3 channels and are 32x32 pixels. Basically, each image is described by 3 x 32 x 32 = 3072 numbers ranging from 0 to 255.
# A common practice in neural networks is to normalize the input, which is done for multiple reasons,
# including avoiding saturation in commonly used activation functions and increasing numerical stability.
# Our normalization process consists of subtracting the mean and dividing by the standard deviation along each channel.
# The tensors ``mean=[0.485, 0.456, 0.406]`` and ``std=[0.229, 0.224, 0.225]`` were already computed,
# and they represent the mean and standard deviation of each channel in the
# predefined subset of CIFAR-10 intended to be the training set.
# Notice how we use these values for the test set as well, without recomputing the mean and standard deviation from scratch.
# This is because the network was trained on features produced by subtracting and dividing the numbers above, and we want to maintain consistency.
# Furthermore, in real life, we would not be able to compute the mean and standard deviation of the test set since,
# under our assumptions, this data would not be accessible at that point.
#
# As a closing point, we often refer to this held-out set as the validation set, and we use a separate set,
# called the test set, after optimizing a model's performance on the validation set.
# This is done to avoid selecting a model based on the greedy and biased optimization of a single metric.
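#
# If you are curious where such per-channel statistics come from, they can be estimated from the training
# images alone. A minimal, optional sketch (not needed for the rest of the tutorial; it assumes the dataset
# fits in memory after ``ToTensor``):
#
# .. code-block:: python
#
#    #raw_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
#    #all_images = torch.stack([img for img, _ in raw_train])                # 50000 x 3 x 32 x 32
#    #print(all_images.mean(dim=(0, 2, 3)), all_images.std(dim=(0, 2, 3)))   # per-channel mean and std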

# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

########################################################################
# .. note:: This section is intended for CPU users who are interested in quick results. Use this option only if you're interested in a small-scale experiment. Keep in mind the code should run fairly quickly using any GPU. Select only the first ``num_images_to_keep`` images from the train/test dataset
#
#    .. code-block:: python
#
#       #from torch.utils.data import Subset
#       #num_images_to_keep = 2000
#       #train_dataset = Subset(train_dataset, range(min(num_images_to_keep, 50_000)))
#       #test_dataset = Subset(test_dataset, range(min(num_images_to_keep, 10_000)))

# Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

######################################################################
# Defining model classes and utility functions
# --------------------------------------------
# Next, we need to define our model classes. Several user-defined parameters need to be set here. We use two different architectures, keeping the number of filters fixed across our experiments to ensure fair comparisons.
# Both architectures are Convolutional Neural Networks (CNNs) with a different number of convolutional layers that serve as feature extractors, followed by a classifier with 10 classes.
# The number of filters and neurons is smaller for the students.
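# As a quick sanity check on the classifier input sizes used below: with 32x32 inputs, each 2x2 max-pooling layer
# halves the spatial resolution, so both feature extractors end with 8x8 maps. A small sketch of the arithmetic
# (illustration only):
#
# .. code-block:: python
#
#    #teacher_flattened_features = 32 * 8 * 8   # 32 output filters -> 2048, the first ``nn.Linear`` input of DeepNN
#    #student_flattened_features = 16 * 8 * 8   # 16 output filters -> 1024, the first ``nn.Linear`` input of LightNN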

# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

######################################################################
# We employ two functions to help us produce and evaluate the results on our original classification task.
# One function is called ``train`` and takes the following arguments:
#
# - ``model``: A model instance to train (update its weights) via this function.
# - ``train_loader``: We defined our ``train_loader`` above, and its job is to feed the data into the model.
# - ``epochs``: How many times we loop over the dataset.
# - ``learning_rate``: The learning rate determines how large our steps towards convergence should be. Too large or too small steps can be detrimental.
# - ``device``: Determines the device to run the workload on. Can be either CPU or GPU depending on availability.
#
# Our test function is similar, but it will be invoked with ``test_loader`` to load images from the test set.
#
# .. figure:: /../_static/img/knowledge_distillation/ce_only.png
#    :align: center
#
#    Train both networks with Cross-Entropy. The student will be used as a baseline:
#

def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

######################################################################
# Cross-entropy runs
# ------------------
# For reproducibility, we need to set the torch manual seed. We train networks using different methods, so to compare them fairly,
# it makes sense to initialize the networks with the same weights.
# Start by training the teacher network using cross-entropy:

torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)

# Instantiate the lightweight network:
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)

######################################################################
# We instantiate one more lightweight network model to compare their performances.
# Backpropagation is sensitive to weight initialization,
# so we need to make sure these two networks have exactly the same initialization.

torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)

######################################################################
# To ensure we have created a copy of the first network, we inspect the norm of its first layer.
# If it matches, then we are safe to conclude that the networks are indeed the same.

# Print the norm of the first layer of the initial lightweight model
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
# Print the norm of the first layer of the new lightweight model
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())
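
######################################################################
# The norm is only a quick heuristic. For a stricter check that the two lightweight networks start from identical
# weights, you could compare every parameter tensor directly (an optional sketch, not part of the original recipe):
#
# .. code-block:: python
#
#    #same_init = all(torch.equal(p1, p2) for p1, p2 in zip(nn_light.parameters(), new_nn_light.parameters()))
#    #print("Identical initialization:", same_init)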

######################################################################
# Print the total number of parameters in each model:
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

######################################################################
# Train and test the lightweight network with cross entropy loss:
train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)

######################################################################
# Based on test accuracy, we can now compare the deeper network that is to be used as a teacher with the lightweight network that is our supposed student. So far, the teacher has not intervened with the student, so this performance is achieved by the student on its own.
# The metrics so far can be seen with the following lines:

print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")

######################################################################
# Knowledge distillation run
# --------------------------
# Now let's try to improve the test accuracy of the student network by incorporating the teacher.
# Knowledge distillation is a straightforward technique to achieve this,
# based on the fact that both networks output a probability distribution over our classes.
# Therefore, the two networks share the same number of output neurons.
# The method works by incorporating an additional loss into the traditional cross entropy loss,
# which is based on the softmax output of the teacher network.
# The assumption is that the output activations of a properly trained teacher network carry additional information that can be leveraged by a student network during training.
# The original work suggests that utilizing ratios of smaller probabilities in the soft targets can help achieve the underlying objective of deep neural networks,
# which is to create a similarity structure over the data where similar objects are mapped closer together.
# For example, in CIFAR-10, a truck could be mistaken for an automobile or airplane
# if its wheels are present, but it is less likely to be mistaken for a dog.
# Therefore, it makes sense to assume that valuable information resides not only in the top prediction of a properly trained model but in the entire output distribution.
# However, cross entropy alone does not sufficiently exploit this information, as the activations for non-predicted classes
# tend to be so small that propagated gradients do not meaningfully change the weights to construct this desirable vector space.
#
# As we continue defining our first helper function that introduces a teacher-student dynamic, we need to include a few extra parameters:
#
# - ``T``: Temperature controls the smoothness of the output distributions. Larger ``T`` leads to smoother distributions, thus smaller probabilities get a larger boost (a short sketch below the figure illustrates this).
# - ``soft_target_loss_weight``: A weight assigned to the extra objective we're about to include.
# - ``ce_loss_weight``: A weight assigned to cross-entropy. Tuning these weights pushes the network towards optimizing for either objective.
#
# .. figure:: /../_static/img/knowledge_distillation/distillation_output_loss.png
#    :align: center
#
#    Distillation loss is calculated from the logits of the networks. It only returns gradients to the student:
#
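# Before writing the training loop, here is a quick illustration of how the temperature softens a distribution
# (a small sketch with made-up logits, not part of the training code below):
#
# .. code-block:: python
#
#    #logits = torch.tensor([[5.0, 2.0, 1.0]])
#    #print(nn.functional.softmax(logits / 1.0, dim=-1))  # sharp: most of the mass on the first class
#    #print(nn.functional.softmax(logits / 4.0, dim=-1))  # smoother: the smaller classes become more visible
#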

def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Soften the teacher's logits with softmax and the student's logits with log-softmax, both at temperature T
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
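
######################################################################
# A note on the soft-target term above: since ``soft_targets`` holds probabilities and ``soft_prob`` holds
# log-probabilities, summing ``soft_targets * (soft_targets.log() - soft_prob)`` and dividing by the batch size
# is the batch-mean KL divergence between the two softened distributions. An equivalent way to write it
# (a sketch, not used in the loop above) is:
#
# .. code-block:: python
#
#    #kl_div = nn.KLDivLoss(reduction="batchmean")
#    #soft_targets_loss = kl_div(soft_prob, soft_targets) * (T**2)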

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

######################################################################
# Cosine loss minimization run
# ----------------------------
# Feel free to play around with the temperature parameter that controls the softness of the softmax function and the loss coefficients.
# In neural networks, it is easy to include additional loss functions to the main objectives to achieve goals like better generalization.
# Let's try including an objective for the student, but now let's focus on their hidden states rather than their output layers.
# Our goal is to convey information from the teacher's representation to the student by including a naive loss function,
# whose minimization implies that the flattened vectors that are subsequently passed to the classifiers have become more *similar* as the loss decreases.
# Of course, the teacher does not update its weights, so the minimization depends only on the student's weights.
# The rationale behind this method is that we are operating under the assumption that the teacher model has a better internal representation that is
# unlikely to be achieved by the student without external intervention, therefore we artificially push the student to mimic the internal representation of the teacher.
# Whether or not this will end up helping the student is not straightforward, though, because pushing the lightweight network
# to reach this point could be a good thing, assuming that we have found an internal representation that leads to better test accuracy,
# but it could also be harmful because the networks have different architectures and the student does not have the same learning capacity as the teacher.
# In other words, there is no reason for these two vectors, the student's and the teacher's, to match per component.
# The student could reach an internal representation that is a permutation of the teacher's, and it would be just as efficient.
# Nonetheless, we can still run a quick experiment to figure out the impact of this method.
# We will be using the ``CosineEmbeddingLoss`` which is given by the following formula:
#
# .. figure:: /../_static/img/knowledge_distillation/cosine_embedding_loss.png
#    :align: center
#    :width: 450px
#
#    Formula for CosineEmbeddingLoss
#
# Obviously, there is one thing that we need to resolve first.
# When we applied distillation to the output layer, we mentioned that both networks have the same number of neurons, equal to the number of classes.
# However, this is not the case for the layer following our convolutional layers. Here, the teacher has more neurons than the student
# after the flattening of the final convolutional layer. Our loss function accepts two vectors of equal dimensionality as inputs,
# therefore we need to somehow match them. We will solve this by including an average pooling layer after the teacher's convolutional layer to reduce its dimensionality to match that of the student, as sketched below.
#
# To proceed, we will modify our model classes, or create new ones.
# Now, the forward function returns not only the logits of the network but also the flattened hidden representation after the convolutional layer. We include the aforementioned pooling for the modified teacher.
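#
# To see what this pooling does to the dimensionality, here is a quick sketch (illustration only, with a random
# tensor standing in for the flattened teacher output):
#
# .. code-block:: python
#
#    #flattened = torch.randn(128, 2048)                      # a batch of flattened teacher features
#    #pooled = torch.nn.functional.avg_pool1d(flattened, 2)   # averages each pair of neighboring values
#    #print(pooled.shape)                                     # torch.Size([128, 1024])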

class ModifiedDeepNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedDeepNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)
        return x, flattened_conv_output_after_pooling

# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
class ModifiedLightNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        return x, flattened_conv_output

# We do not have to train the modified deep network from scratch; we just load its weights from the trained instance
modified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device)
modified_nn_deep.load_state_dict(nn_deep.state_dict())

# Once again ensure the norm of the first layer is the same for both networks
print("Norm of 1st layer for deep_nn:", torch.norm(nn_deep.features[0].weight).item())
print("Norm of 1st layer for modified_deep_nn:", torch.norm(modified_nn_deep.features[0].weight).item())

# Initialize a modified lightweight network with the same seed as our other lightweight instances. This will be trained from scratch to examine the effectiveness of cosine loss minimization.
torch.manual_seed(42)
modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
print("Norm of 1st layer:", torch.norm(modified_nn_light.features[0].weight).item())

######################################################################
# Naturally, we need to change the train loop because now the model returns a tuple ``(logits, hidden_representation)``. Using a sample input tensor,
# we can print their shapes.

# Create a sample input tensor
sample_input = torch.randn(128, 3, 32, 32).to(device) # Batch size: 128, Channels: 3, Image size: 32x32

# Pass the input through the student
logits, hidden_representation = modified_nn_light(sample_input)

# Print the shapes of the tensors
print("Student logits shape:", logits.shape) # batch_size x total_classes
print("Student hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size

# Pass the input through the teacher
logits, hidden_representation = modified_nn_deep(sample_input)

# Print the shapes of the tensors
print("Teacher logits shape:", logits.shape) # batch_size x total_classes
print("Teacher hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size

######################################################################
# In our case, ``hidden_representation_size`` is ``1024``. This is the flattened feature map of the final convolutional layer of the student, and as you can see,
# it is the input for its classifier. It is ``1024`` for the teacher too, because we made it so with ``avg_pool1d`` from ``2048``.
# The loss applied here only affects the weights of the student prior to the loss calculation. In other words, it does not affect the classifier of the student.
# The modified training loop is the following:
#
# .. figure:: /../_static/img/knowledge_distillation/cosine_loss_distillation.png
#    :align: center
#
#    In Cosine Loss minimization, we want to maximize the cosine similarity of the two representations by returning gradients to the student:
#
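# As a reminder of how ``CosineEmbeddingLoss`` behaves with a target of ``1`` (the setting used below), here is a
# small standalone sketch (illustration only, with random vectors):
#
# .. code-block:: python
#
#    #cos_loss = nn.CosineEmbeddingLoss()
#    #a, b = torch.randn(128, 1024), torch.randn(128, 1024)
#    #target = torch.ones(128)          # target = 1 asks the loss to make each pair more similar
#    #print(cos_loss(a, b, target))     # mean of 1 - cos(a_i, b_i) over the batch
#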

def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    cosine_loss = nn.CosineEmbeddingLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model and keep only the hidden representation
            with torch.no_grad():
                _, teacher_hidden_representation = teacher(inputs)

            # Forward pass with the student model
            student_logits, student_hidden_representation = student(inputs)

            # Calculate the cosine loss. The target is a vector of ones. From the loss formula above we can see that this is the case where loss minimization leads to an increase in cosine similarity.
            hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

######################################################################
# We need to modify our test function for the same reason. Here we ignore the hidden representation returned by the model.

def test_multiple_outputs(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs, _ = model(inputs) # Disregard the second tensor of the tuple
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

######################################################################
# In this case, we could easily include both knowledge distillation and cosine loss minimization in the same function. It is common to combine methods to achieve better performance in teacher-student paradigms.
# For now, we can run a simple train-test session.

# Train and test the lightweight network with cross entropy and cosine loss
train_cosine_loss(teacher=modified_nn_deep, student=modified_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light, test_loader, device)

######################################################################
# Intermediate regressor run
# --------------------------
# Our naive minimization does not guarantee better results for several reasons, one being the dimensionality of the vectors.
# Cosine similarity generally works better than Euclidean distance for vectors of higher dimensionality,
# but we were dealing with vectors with 1024 components each, so it is much harder to extract meaningful similarities.
# Furthermore, as we mentioned, pushing towards a match of the hidden representations of the teacher and the student is not supported by theory.
# There are no good reasons why we should be aiming for a 1:1 match of these vectors.
# We will provide a final example of training intervention by including an extra network called a regressor.
# The objective is to first extract the feature map of the teacher after a convolutional layer,
# then extract a feature map of the student after a convolutional layer, and finally try to match these maps.
# However, this time, we will introduce a regressor between the networks to facilitate the matching process.
# The regressor will be trainable and ideally will do a better job than our naive cosine loss minimization scheme.
# Its main job is to match the dimensionality of these feature maps so that we can properly define a loss function between the teacher and the student.
# Defining such a loss function provides a teaching "path," which is basically a flow to back-propagate gradients that will change the student's weights.
# Focusing on the output of the convolutional layers right before each classifier for our original networks, we have the following shapes:

# Pass the sample input only from the convolutional feature extractor
convolutional_fe_output_student = nn_light.features(sample_input)
convolutional_fe_output_teacher = nn_deep.features(sample_input)

# Print their shapes
print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape)
print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape)

######################################################################
# We have 32 filters for the teacher and 16 filters for the student.
# We will include a trainable layer that converts the feature map of the student to the shape of the feature map of the teacher.
# In practice, we modify the lightweight class to return the hidden state after an intermediate regressor that matches the sizes of the convolutional
# feature maps, and the teacher class to return the output of the final convolutional layer without pooling or flattening.
#
# .. figure:: /../_static/img/knowledge_distillation/fitnets_knowledge_distill.png
#    :align: center
#
#    The trainable layer matches the shapes of the intermediate tensors and Mean Squared Error (MSE) is properly defined:
#
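# The regressor itself can be as simple as a single convolution that changes the number of channels. A quick shape
# check of that idea (illustration only, using random tensors with the shapes printed above):
#
# .. code-block:: python
#
#    #regressor = nn.Conv2d(16, 32, kernel_size=3, padding=1)
#    #student_map = torch.randn(128, 16, 8, 8)     # student feature map shape
#    #print(regressor(student_map).shape)          # torch.Size([128, 32, 8, 8]), matching the teacher
#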

class ModifiedDeepNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedDeepNNRegressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        conv_feature_map = x
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map

class ModifiedLightNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNRegressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Include an extra regressor (in our case a convolutional layer with no nonlinearity)
        self.regressor = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1)
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        regressor_output = self.regressor(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output

######################################################################
# After that, we have to update our train loop again. This time, we extract the regressor output of the student and the feature map of the teacher,
# we calculate the ``MSE`` on these tensors (they have the exact same shape, so it is properly defined), and we back-propagate gradients based on that loss,
# in addition to the regular cross entropy loss of the classification task.

def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Again ignore teacher logits
            with torch.no_grad():
                _, teacher_feature_map = teacher(inputs)

            # Forward pass with the student model
            student_logits, regressor_feature_map = student(inputs)

            # Calculate the loss
            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Notice how our test function remains the same as the one we used in our previous case. We only care about the actual outputs because we measure accuracy.

# Initialize a ModifiedLightNNRegressor
torch.manual_seed(42)
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device)

# We do not have to train the modified deep network from scratch; we just load its weights from the trained instance
modified_nn_deep_reg = ModifiedDeepNNRegressor(num_classes=10).to(device)
modified_nn_deep_reg.load_state_dict(nn_deep.state_dict())

# Train and test once again
train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_nn_light_reg, test_loader, device)

######################################################################
# It is expected that the final method will work better than ``CosineLoss``, because now we have allowed a trainable layer between the teacher and the student,
# which gives the student some wiggle room when it comes to learning, rather than pushing the student to copy the teacher's representation.
# Including the extra network is the idea behind hint-based distillation.

print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
print(f"Student accuracy with CE + CosineLoss: {test_accuracy_light_ce_and_cosine_loss:.2f}%")
print(f"Student accuracy with CE + RegressorMSE: {test_accuracy_light_ce_and_mse_loss:.2f}%")

######################################################################
# Conclusion
# --------------------------------------------
# None of the methods above increases the number of parameters of the network or its inference time,
# so the performance increase comes at the small cost of calculating gradients during training.
# In ML applications, we mostly care about inference time because training happens before the model deployment.
# If our lightweight model is still too heavy for deployment, we can apply different ideas, such as post-training quantization.
# Additional losses can be applied in many tasks, not just classification, and you can experiment with quantities like coefficients,
# temperature, or number of neurons. Feel free to tune any numbers in the tutorial above,
# but keep in mind that if you change the number of neurons or filters, a shape mismatch is likely to occur.
#
# For more information, see:
#
# * `Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a neural network. In: Neural Information Processing Systems Deep Learning Workshop (2015) <https://arxiv.org/abs/1503.02531>`_
#
# * `Romero, A., Ballas, N., Kahou, S.E., Chassang, A., Gatta, C., Bengio, Y.: FitNets: Hints for thin deep nets. In: Proceedings of the International Conference on Learning Representations (2015) <https://arxiv.org/abs/1412.6550>`_