"""
`Introduction <introyt1_tutorial.html>`_ ||
`Tensors <tensors_deeper_tutorial.html>`_ ||
`Autograd <autogradyt_tutorial.html>`_ ||
`Building Models <modelsyt_tutorial.html>`_ ||
`TensorBoard Support <tensorboardyt_tutorial.html>`_ ||
**Training Models** ||
`Model Understanding <captumyt.html>`_

Training with PyTorch
=====================

Follow along with the video below or on `youtube <https://www.youtube.com/watch?v=jF43_wj_DCQ>`__.

.. raw:: html

   <div style="margin-top:10px; margin-bottom:10px;">
     <iframe width="560" height="315" src="https://www.youtube.com/embed/jF43_wj_DCQ" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
   </div>

Introduction
------------

In past videos, we’ve discussed and demonstrated:

- Building models with the neural network layers and functions of the ``torch.nn`` module
- The mechanics of automated gradient computation, which is central to
  gradient-based model training
- Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

- We’ll get familiar with the dataset and dataloader abstractions, and how
  they ease the process of feeding data to your model during a training loop
- We’ll discuss specific loss functions and when to use them
- We’ll look at PyTorch optimizers, which implement algorithms to adjust
  model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch
training loop in action.


Dataset and DataLoader
----------------------

The ``Dataset`` and ``DataLoader`` classes encapsulate the process of
pulling your data from storage and exposing it to your training loop in
batches.

The ``Dataset`` is responsible for accessing and processing single
instances of data.

The ``DataLoader`` pulls instances of data from the ``Dataset`` (either
automatically or with a sampler that you define), collects them in
batches, and returns them for consumption by your training loop. The
``DataLoader`` works with all kinds of datasets, regardless of the type
of data they contain.

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by
TorchVision. We use ``torchvision.transforms.Normalize()`` to
zero-center and normalize the distribution of the image tile content,
and download both training and validation data splits.

"""

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
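

######################################################################
# A brief aside (an addition to this write-up, not part of the video): a
# map-style ``Dataset`` of your own only needs ``__len__`` and
# ``__getitem__``. A minimal in-memory sketch, assuming ``samples`` and
# ``labels`` are equal-length sequences:
#
# .. code:: python
#
#     class MyDataset(torch.utils.data.Dataset):
#         def __init__(self, samples, labels):
#             self.samples = samples
#             self.labels = labels
#
#         def __len__(self):
#             return len(self.samples)
#
#         def __getitem__(self, idx):
#             return self.samples[idx], self.labels[idx]
#
# A ``DataLoader`` would batch this exactly as it batches the Fashion-MNIST
# sets above.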


######################################################################
# As always, let’s visualize the data as a sanity check:
#

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))


#########################################################################
# The Model
# ---------
#
# The model we’ll use in this example is a variant of LeNet-5 - it should
# be familiar if you’ve watched the previous videos in this series.
#

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
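

######################################################################
# As a quick sanity check (an extra step added for this write-up): a batch
# of four 28x28 single-channel images should come out as four 10-element
# rows of per-class scores.

with torch.no_grad():
    print(model(torch.rand(4, 1, 28, 28)).shape)    # expect torch.Size([4, 10])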


##########################################################################
# Loss Function
# -------------
#
# For this example, we’ll be using a cross-entropy loss. For demonstration
# purposes, we’ll create batches of dummy output and label values, run
# them through the loss function, and examine the result.
#

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
# CrossEntropyLoss averages over the batch by default (reduction='mean')
print('Average loss for this batch: {}'.format(loss.item()))
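

######################################################################
# An aside added for this write-up: ``CrossEntropyLoss`` combines
# ``LogSoftmax`` and ``NLLLoss`` in a single class, so the value above can
# be reproduced in two explicit steps:

manual_loss = torch.nn.NLLLoss()(F.log_softmax(dummy_outputs, dim=1), dummy_labels)
print('Via LogSoftmax + NLLLoss: {}'.format(manual_loss.item()))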


#################################################################################
# Optimizer
# ---------
#
# For this example, we’ll be using simple `stochastic gradient
# descent <https://pytorch.org/docs/stable/optim.html>`__ with momentum.
#
# It can be instructive to try some variations on this optimization
# scheme:
#
# - Learning rate determines the size of the steps the optimizer
#   takes. What does a different learning rate do to your training
#   results, in terms of accuracy and convergence time?
# - Momentum nudges the optimizer in the direction of strongest gradient over
#   multiple steps. What does changing this value do to your results?
# - Try some different optimization algorithms, such as averaged SGD, Adagrad, or
#   Adam. How do your results differ? (One such swap is sketched after the
#   code below.)
#

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
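

######################################################################
# As one concrete example of the swap suggested above (shown for
# illustration only; the rest of this tutorial keeps SGD), Adam could be
# dropped in with:
#
# .. code:: python
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)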


#######################################################################################
# The Training Loop
# -----------------
#
# Below, we have a function that performs one training epoch. It
# enumerates data from the DataLoader, and on each pass of the loop does
# the following:
#
# - Gets a batch of training data from the DataLoader
# - Zeros the optimizer’s gradients
# - Performs an inference - that is, gets predictions from the model for an input batch
# - Calculates the loss for that set of predictions vs. the labels on the dataset
# - Calculates the backward gradients over the learning weights
# - Tells the optimizer to perform one learning step - that is, adjust the model’s
#   learning weights based on the observed gradients for this batch, according to the
#   optimization algorithm we chose
# - Reports the loss for every 1000 batches
# - Finally, returns the average per-batch loss over the last
#   1000 batches, for comparison with a validation run
#

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print(' batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss


##################################################################################
# Per-Epoch Activity
# ~~~~~~~~~~~~~~~~~~
#
# There are a couple of things we’ll want to do once per epoch:
#
# - Perform validation by checking our relative loss on a set of data that was not
#   used for training, and report this
# - Save a copy of the model
#
# Here, we’ll do our reporting in TensorBoard. This will require going to
# the command line to start TensorBoard, and opening it in another browser
# tab.
#
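# (For example, with the log directory used below, this is typically a
# matter of running ``tensorboard --logdir=runs`` in a shell and browsing
# to the address it prints.)
#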

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                       { 'Training' : avg_loss, 'Validation' : avg_vloss },
                       epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1


#########################################################################
# To load a saved version of the model:
#
# .. code:: python
#
#     saved_model = GarmentClassifier()
#     saved_model.load_state_dict(torch.load(PATH))
#
# Once you’ve loaded the model, it’s ready for whatever you need it for -
# more training, inference, or analysis.
#
# Note that if your model has constructor parameters that affect model
# structure, you’ll need to provide them and configure the model
# identically to the state in which it was saved.
#
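# (For instance - using a hypothetical ``hidden_units`` parameter purely
# for illustration, not one this class actually has - a model built as
# ``GarmentClassifier(hidden_units)`` would have to be re-created with the
# same value before loading its saved ``state_dict``.)
#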
# Other Resources
# ---------------
#
# - Docs on the `data
#   utilities <https://pytorch.org/docs/stable/data.html>`__, including
#   Dataset and DataLoader, at pytorch.org
# - A `note on the use of pinned
#   memory <https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-pinning>`__
#   for GPU training
# - Documentation on the datasets available in
#   `TorchVision <https://pytorch.org/vision/stable/datasets.html>`__,
#   `TorchText <https://pytorch.org/text/stable/datasets.html>`__, and
#   `TorchAudio <https://pytorch.org/audio/stable/datasets.html>`__
# - Documentation on the `loss
#   functions <https://pytorch.org/docs/stable/nn.html#loss-functions>`__
#   available in PyTorch
# - Documentation on the `torch.optim
#   package <https://pytorch.org/docs/stable/optim.html>`__, which
#   includes optimizers and related tools, such as learning rate
#   scheduling
# - A detailed `tutorial on saving and loading
#   models <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`__
# - The `Tutorials section of
#   pytorch.org <https://pytorch.org/tutorials/>`__ contains tutorials on
#   a broad variety of training tasks, including classification in
#   different domains, generative adversarial networks, reinforcement
#   learning, and more
#