GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/introyt/trainingyt.py
"""
`Introduction <introyt1_tutorial.html>`_ ||
`Tensors <tensors_deeper_tutorial.html>`_ ||
`Autograd <autogradyt_tutorial.html>`_ ||
`Building Models <modelsyt_tutorial.html>`_ ||
`TensorBoard Support <tensorboardyt_tutorial.html>`_ ||
**Training Models** ||
`Model Understanding <captumyt.html>`_

Training with PyTorch
=====================

Follow along with the video below or on `youtube <https://www.youtube.com/watch?v=jF43_wj_DCQ>`__.

.. raw:: html

   <div style="margin-top:10px; margin-bottom:10px;">
     <iframe width="560" height="315" src="https://www.youtube.com/embed/jF43_wj_DCQ" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
   </div>

Introduction
------------

In past videos, we’ve discussed and demonstrated:

- Building models with the neural network layers and functions of the ``torch.nn`` module
- The mechanics of automated gradient computation, which is central to
  gradient-based model training
- Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

- We’ll get familiar with the dataset and dataloader abstractions, and how
  they ease the process of feeding data to your model during a training loop
- We’ll discuss specific loss functions and when to use them
- We’ll look at PyTorch optimizers, which implement algorithms to adjust
  model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch
training loop in action.


Dataset and DataLoader
----------------------

The ``Dataset`` and ``DataLoader`` classes encapsulate the process of
pulling your data from storage and exposing it to your training loop in
batches.

The ``Dataset`` is responsible for accessing and processing single
instances of data.

The ``DataLoader`` pulls instances of data from the ``Dataset`` (either
automatically or with a sampler that you define), collects them in
batches, and returns them for consumption by your training loop. The
``DataLoader`` works with all kinds of datasets, regardless of the type
of data they contain.

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by
TorchVision. Each image is converted to a ``float32`` tensor with pixel
values in [0, 1]. Then ``torchvision.transforms.v2.Normalize()``, with a
mean and standard deviation of 0.5, rescales those values to [-1, 1].
We download both the training and validation data splits.

"""

import torch
import torchvision
from torchvision.transforms import v2

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = v2.Compose([
    v2.ToImage(),                           # PIL image -> uint8 image tensor (C, H, W)
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    v2.Normalize((0.5,), (0.5,))            # (x - 0.5) / 0.5: [0, 1] -> [-1, 1]
])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print(f'Training set has {len(training_set)} instances')
print(f'Validation set has {len(validation_set)} instances')


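######################################################################
# The ``Dataset`` / ``DataLoader`` split described above is not specific
# to images. Here is a minimal sketch using a small synthetic dataset.
# The ``ToyDataset`` class below is hypothetical and not part of PyTorch.
# A map-style ``Dataset`` only needs ``__len__`` and ``__getitem__``, and
# the ``DataLoader`` stacks the individual samples into batches:
#

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical map-style dataset: n random 1x28x28 'images' with integer labels."""

    def __init__(self, n=10):
        g = torch.Generator().manual_seed(0)
        self.images = torch.rand(n, 1, 28, 28, generator=g)
        self.labels = torch.arange(n) % 10

    def __len__(self):
        # The DataLoader's sampler uses this to know which indices exist
        return len(self.labels)

    def __getitem__(self, idx):
        # Return one (input, label) pair; batching is the DataLoader's job
        return self.images[idx], self.labels[idx]


toy_loader = DataLoader(ToyDataset(), batch_size=4, shuffle=False)
batches = list(toy_loader)
for toy_inputs, toy_labels in batches:
    print(toy_inputs.shape, toy_labels.tolist())
```

# With ``shuffle=False`` the samples come back in index order. The last
# batch is smaller because 10 is not a multiple of 4. Passing
# ``drop_last=True`` to the ``DataLoader`` would discard it.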
######################################################################
# As always, let’s visualize the data as a sanity check:
#

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize: undo Normalize((0.5,), (0.5,))
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(len(labels))))


#########################################################################
# The Model
# ---------
#
# The model we’ll use in this example is a variant of LeNet-5 - it should
# be familiar if you’ve watched the previous videos in this series.
#

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)        # 1 input channel -> 6 feature maps, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)         # 2x2 max pooling halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)       # 6 -> 16 feature maps, 5x5 kernel
        self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 16 maps of 4x4 remain after two conv/pool stages
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # one raw score (logit) per class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)             # flatten each image's feature maps
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()


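######################################################################
# Where does ``16 * 4 * 4`` come from? Each unpadded 5x5 convolution
# shrinks a side by 4 pixels, and each 2x2 max-pool halves it, so a 28x28
# input goes 28 -> 24 -> 12 -> 8 -> 4. This sketch rebuilds the
# convolutional layers on their own, with the same shapes as in
# ``GarmentClassifier`` but fresh random weights. It then traces a dummy
# batch of 4 images through them:
#

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv1 = nn.Conv2d(1, 6, 5)   # same shape as GarmentClassifier.conv1
conv2 = nn.Conv2d(6, 16, 5)  # same shape as GarmentClassifier.conv2
pool = nn.MaxPool2d(2, 2)    # 2x2 window, stride 2

x0 = torch.zeros(4, 1, 28, 28)   # dummy batch: N=4, C=1, H=W=28
x1 = F.relu(conv1(x0))           # 28 - 5 + 1 = 24
x2 = pool(x1)                    # 24 / 2 = 12
x3 = F.relu(conv2(x2))           # 12 - 5 + 1 = 8
x4 = pool(x3)                    # 8 / 2 = 4
flat = x4.view(-1, 16 * 4 * 4)   # 16 channels * 4 * 4 = 256 features per image

for name, tensor in [('conv1', x1), ('pool', x2), ('conv2', x3), ('pool', x4), ('flatten', flat)]:
    print(f'{name:>8}: {tuple(tensor.shape)}')
```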
##########################################################################
# Loss Function
# -------------
#
# For this example, we’ll be using a cross-entropy loss, a standard choice
# for multi-class classification. ``torch.nn.CrossEntropyLoss`` expects raw,
# unnormalized model outputs (logits) and integer class labels. For
# demonstration purposes, we’ll create batches of dummy output and label
# values, run them through the loss function, and examine the result.
#

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw scores (logits) for each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

# With the default reduction='mean', this is the average loss over the batch
loss = loss_fn(dummy_outputs, dummy_labels)
print(f'Average loss for this batch: {loss.item()}')


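######################################################################
# With its default ``reduction='mean'``, ``CrossEntropyLoss`` is the batch
# average of the negative log-softmax score assigned to each sample's
# correct class. This short sketch checks that by hand on random dummy data:
#

```python
import torch
import torch.nn.functional as F

gen = torch.Generator().manual_seed(0)
logits = torch.rand(4, 10, generator=gen)   # dummy raw scores: 4 samples, 10 classes
targets = torch.tensor([1, 5, 3, 7])        # correct class index for each sample

builtin = torch.nn.CrossEntropyLoss()(logits, targets)

# Manual version: log-softmax over the class dimension, pick each sample's
# correct-class entry, negate, and average over the batch
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()

print(builtin.item(), manual.item())
```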
#################################################################################
# Optimizer
# ---------
#
# For this example, we’ll be using simple `stochastic gradient
# descent <https://pytorch.org/docs/stable/optim.html>`__ with momentum.
#
# It can be instructive to try some variations on this optimization
# scheme:
#
# - Learning rate determines the size of the steps the optimizer
#   takes. What does a different learning rate do to your training
#   results, in terms of accuracy and convergence time?
# - Momentum accumulates past gradients, so the optimizer keeps moving in
#   directions the gradient has consistently pointed over multiple steps.
#   What does changing this value do to your results?
# - Try some different optimization algorithms, such as averaged SGD, Adagrad, or
#   Adam. How do your results differ?
#

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


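######################################################################
# This example makes momentum concrete. For plain SGD with momentum (no
# dampening, Nesterov, or weight decay), PyTorch keeps a momentum buffer
# ``b`` for each parameter. It sets ``b = g`` on the first step and
# ``b = momentum * b + g`` afterwards, then updates ``p = p - lr * b``.
# The sketch applies that rule by hand to a toy quadratic loss, chosen
# purely for illustration, and compares the result with ``torch.optim.SGD``:
#

```python
import torch

lr, momentum = 0.1, 0.9
target = torch.tensor([3.0, -2.0])   # toy problem: move p toward this point

# Parameter updated by torch.optim.SGD
p = torch.zeros(2, requires_grad=True)
opt = torch.optim.SGD([p], lr=lr, momentum=momentum)

# A second copy of the parameter, plus its momentum buffer, updated by hand
p_manual = torch.zeros(2)
buf = None

for step in range(5):
    # Toy loss: squared distance to target, whose gradient is 2 * (p - target)
    opt.zero_grad()
    toy_loss = ((p - target) ** 2).sum()
    toy_loss.backward()
    opt.step()

    grad = 2 * (p_manual - target)
    buf = grad.clone() if buf is None else momentum * buf + grad
    p_manual = p_manual - lr * buf

print(p.detach(), p_manual)
```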
#######################################################################################
# The Training Loop
# -----------------
#
# Below, we have a function that performs one training epoch. It
# enumerates data from the DataLoader, and on each pass of the loop does
# the following:
#
# - Gets a batch of training data from the DataLoader
# - Zeros the optimizer’s gradients
# - Performs inference - that is, gets predictions from the model for an input batch
# - Calculates the loss for that set of predictions vs. the labels for the batch
# - Calls ``backward()`` to compute the gradients of the loss with respect to
#   the learning weights
# - Tells the optimizer to perform one learning step - that is, adjust the model’s
#   learning weights based on the observed gradients for this batch, according to the
#   optimization algorithm we chose
# - Reports the average per-batch loss every 1000 batches
# - Finally, returns the average per-batch loss for the last
#   1000 batches, for comparison with a validation run
#

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch! (backward() adds to existing .grad values)
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print(f'  batch {i + 1} loss: {last_loss}')
            tb_x = epoch_index * len(training_loader) + i + 1  # global step across epochs
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss


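######################################################################
# The same zero-grad / forward / loss / backward / step sequence works for
# any model and loss. This self-contained sketch swaps in a toy linear
# classifier and synthetic, linearly separable data. Both are hypothetical
# stand-ins for ``GarmentClassifier`` and Fashion-MNIST. A few epochs of
# that loop should drive the average loss down:
#

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

gen = torch.Generator().manual_seed(0)

# Hypothetical synthetic 2-class data: label 1 when the two features sum to > 0
toy_x = torch.randn(256, 2, generator=gen)
toy_y = (toy_x.sum(dim=1) > 0).long()
toy_train_loader = DataLoader(TensorDataset(toy_x, toy_y), batch_size=16,
                              shuffle=True, generator=gen)

# Hypothetical stand-in model: one linear layer producing 2 logits.
# Zero initialization keeps the run deterministic and is fine for a single linear layer.
toy_model = torch.nn.Linear(2, 2)
torch.nn.init.zeros_(toy_model.weight)
torch.nn.init.zeros_(toy_model.bias)
toy_loss_fn = torch.nn.CrossEntropyLoss()
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1, momentum=0.9)

epoch_losses = []
for epoch in range(5):
    running = 0.0
    for batch_x, batch_y in toy_train_loader:
        toy_optimizer.zero_grad()                     # clear gradients left by the previous batch
        batch_out = toy_model(batch_x)                # forward pass: predictions for this batch
        batch_loss = toy_loss_fn(batch_out, batch_y)  # compare predictions with labels
        batch_loss.backward()                         # gradients of the loss w.r.t. the weights
        toy_optimizer.step()                          # one learning step
        running += batch_loss.item()
    epoch_losses.append(running / len(toy_train_loader))  # average per-batch loss this epoch
    print(f'epoch {epoch + 1}: average batch loss {epoch_losses[-1]:.4f}')
```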
##################################################################################
# Per-Epoch Activity
# ~~~~~~~~~~~~~~~~~~
#
# There are a couple of things we’ll want to do once per epoch:
#
# - Perform validation by computing the loss on a set of data that was not
#   used for training, and report it
# - Save a copy of the model whenever its validation loss is the best seen
#   so far
#
# Here, we’ll do our reporting in TensorBoard. This will require going to
# the command line to start TensorBoard (for example, with
# ``tensorboard --logdir=runs``), and opening it in another browser tab.
#

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(f'runs/fashion_trainer_{timestamp}')
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print(f'EPOCH {epoch_number + 1}:')

    # Put the model in training mode (this affects layers such as dropout and
    # batch normalization), and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss.item()

    avg_vloss = running_vloss / (i + 1)
    print(f'LOSS train {avg_loss} valid {avg_vloss}')

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = f'model_{timestamp}_{epoch_number}'
        torch.save(model.state_dict(), model_path)

    epoch_number += 1


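######################################################################
# ``GarmentClassifier`` has no dropout or batch-norm layers, so
# ``model.eval()`` does not change its outputs. The call matters for
# models that do have such layers. The sketch below uses a hypothetical
# toy model with a dropout layer to show both effects relied on in the
# validation pass. In eval mode, dropout is inactive, so repeated forward
# passes agree. Under ``torch.no_grad()``, autograd does not record the
# computation:
#

```python
import torch
import torch.nn as nn

# Hypothetical model with dropout, so train/eval mode makes a visible difference
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
batch = torch.randn(4, 8)

net.train()
train_out = net(batch)
print('training mode, autograd tracking:', train_out.requires_grad)

net.eval()
with torch.no_grad():
    eval_out_1 = net(batch)
    eval_out_2 = net(batch)
print('eval mode, same output twice:', torch.equal(eval_out_1, eval_out_2))
print('under no_grad, autograd tracking:', eval_out_1.requires_grad)
```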
#########################################################################
# To load a saved version of the model:
#
# .. code:: python
#
#     saved_model = GarmentClassifier()
#     saved_model.load_state_dict(torch.load(PATH))
#
# Here, ``PATH`` is the path of a saved state dictionary, such as one of the
# ``model_path`` files written by the training loop above.
#
# Once you’ve loaded the model, it’s ready for whatever you need it for -
# more training, inference, or analysis.
#
# Note that if your model has constructor parameters that affect model
# structure, you’ll need to provide them and configure the model
# identically to the state in which it was saved.
#
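######################################################################
# Here is a minimal save/load round trip. It assumes an in-memory
# ``io.BytesIO`` buffer in place of the ``model_path`` file, and a small
# hypothetical model in place of ``GarmentClassifier``. The sketch saves the
# ``state_dict``, builds a fresh instance with the same structure, loads
# the weights, and checks that both models produce identical outputs:
#

```python
import io
import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in; any nn.Module is saved and loaded the same way
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


original = make_model()
buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)   # a file path works the same way

buffer.seek(0)
reloaded = make_model()                     # must match the saved model's structure
reloaded.load_state_dict(torch.load(buffer))
reloaded.eval()

probe = torch.randn(5, 4)
with torch.no_grad():
    print(torch.equal(original(probe), reloaded(probe)))
```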
######################################################################
# Other Resources
# ---------------
#
# - Docs on the `data
#   utilities <https://pytorch.org/docs/stable/data.html>`__, including
#   Dataset and DataLoader, at pytorch.org
# - A `note on the use of pinned
#   memory <https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-pinning>`__
#   for GPU training
# - Documentation on the datasets available in
#   `TorchVision <https://pytorch.org/vision/stable/datasets.html>`__,
#   `TorchText <https://pytorch.org/text/stable/datasets.html>`__, and
#   `TorchAudio <https://pytorch.org/audio/stable/datasets.html>`__
# - Documentation on the `loss
#   functions <https://pytorch.org/docs/stable/nn.html#loss-functions>`__
#   available in PyTorch
# - Documentation on the `torch.optim
#   package <https://pytorch.org/docs/stable/optim.html>`__, which
#   includes optimizers and related tools, such as learning rate
#   scheduling
# - A detailed `tutorial on saving and loading
#   models <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`__
# - The `Tutorials section of
#   pytorch.org <https://pytorch.org/tutorials/>`__ contains tutorials on
#   a broad variety of training tasks, including classification in
#   different domains, generative adversarial networks, reinforcement
#   learning, and more
#