GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/introyt/introyt1_tutorial.py

"""
**Introduction** ||
`Tensors <tensors_deeper_tutorial.html>`_ ||
`Autograd <autogradyt_tutorial.html>`_ ||
`Building Models <modelsyt_tutorial.html>`_ ||
`TensorBoard Support <tensorboardyt_tutorial.html>`_ ||
`Training Models <trainingyt.html>`_ ||
`Model Understanding <captumyt.html>`_

Introduction to PyTorch
=======================

Follow along with the video below or on `YouTube <https://www.youtube.com/watch?v=IC0_FRiX-sw>`__.

.. raw:: html

   <div style="margin-top:10px; margin-bottom:10px;">
     <iframe width="560" height="315" src="https://www.youtube.com/embed/IC0_FRiX-sw" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
   </div>

PyTorch Tensors
---------------

Follow along with the video beginning at `03:50 <https://www.youtube.com/watch?v=IC0_FRiX-sw&t=230s>`__.

First, we’ll import PyTorch.

"""


import torch


######################################################################
# Let’s see a few basic tensor manipulations. First, just a few of the
# ways to create tensors:
#

z = torch.zeros(5, 3)
print(z)
print(z.dtype)


#########################################################################
# Above, we create a 5x3 matrix filled with zeros, and query its datatype
# to find out that the zeros are 32-bit floating point numbers, which is
# the default in PyTorch.
#
# What if you wanted integers instead? You can always override the
# default:
#

i = torch.ones((5, 3), dtype=torch.int16)
print(i)
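

######################################################################
# A small sketch of a few related dtype operations, assuming the global
# default has not been changed with ``torch.set_default_dtype()``: you
# can query the default, convert an existing tensor with ``.to()``, and
# mixing dtypes in one operation promotes the result to a type that can
# represent both operands.
#

```python
import torch

# Floating-point factory functions use the global default dtype
print(torch.get_default_dtype())           # torch.float32 unless changed

a = torch.zeros(2, 2)                      # inherits the default dtype
b = torch.zeros(2, 2, dtype=torch.int16)   # explicit override at creation time
c = b.to(torch.float64)                    # convert an existing tensor

print(a.dtype, b.dtype, c.dtype)           # torch.float32 torch.int16 torch.float64

# int16 + float32: the result is promoted to float32
d = b + a
print(d.dtype)                             # torch.float32
```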


######################################################################
# You can see that when we do change the default, the tensor helpfully
# reports this when printed.
#
# It’s common to initialize learning weights randomly, often with a
# specific seed for the PRNG for reproducibility of results:
#

torch.manual_seed(1729)
r1 = torch.rand(2, 2)
print('A random tensor:')
print(r1)

r2 = torch.rand(2, 2)
print('\nA different random tensor:')
print(r2) # new values

torch.manual_seed(1729)
r3 = torch.rand(2, 2)
print('\nShould match r1:')
print(r3) # repeats values of r1 because of re-seed
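

######################################################################
# Re-seeding with ``torch.manual_seed()`` resets the global generator,
# which affects every random call that follows. If you only want one piece
# of code to be reproducible, a sketch using a dedicated ``torch.Generator``
# (passed through the ``generator`` keyword that ``torch.rand`` accepts)
# keeps that state separate:
#

```python
import torch

# A private generator with its own PRNG state, independent of the global one
g = torch.Generator().manual_seed(1729)
a = torch.rand(2, 2, generator=g)

g.manual_seed(1729)                 # re-seed only this generator
b = torch.rand(2, 2, generator=g)

print(torch.equal(a, b))            # True: same seed, same draws
```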


#######################################################################
# PyTorch tensors perform arithmetic operations intuitively. Tensors of
# the same shape may be added, multiplied, etc., element by element.
# Operations with scalars are distributed over the tensor:
#

ones = torch.ones(2, 3)
print(ones)

twos = torch.ones(2, 3) * 2 # every element is multiplied by 2
print(twos)

threes = ones + twos # addition allowed because shapes are identical
print(threes) # tensors are added element-wise
print(threes.shape) # this has the same dimensions as input tensors

r1 = torch.rand(2, 3)
r2 = torch.rand(3, 2)
# uncomment this line to get a runtime error: shapes (2, 3) and (3, 2)
# are not compatible for element-wise addition
# r3 = r1 + r2
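

######################################################################
# The rule PyTorch applies when shapes differ is called *broadcasting*,
# and it follows the same rules as NumPy. A minimal sketch: sizes are
# compared from the last dimension backwards, and each pair must be equal
# or contain a 1, so compatible shapes are expanded automatically while
# the ``(2, 3)`` and ``(3, 2)`` tensors above raise a ``RuntimeError``:
#

```python
import torch

ones = torch.ones(2, 3)
row = torch.tensor([10.0, 20.0, 30.0])     # shape (3,)
column = torch.tensor([[100.0], [200.0]])  # shape (2, 1)

# (2, 3) + (3,): the missing leading dimension counts as 1 -> (2, 3)
summed = ones + row
print(summed.shape)                        # torch.Size([2, 3])

# (2, 3) + (2, 1): the size-1 column is repeated across 3 columns -> (2, 3)
print((ones + column).shape)               # torch.Size([2, 3])

mismatch_raised = False
try:
    torch.rand(2, 3) + torch.rand(3, 2)    # last sizes 3 and 2 differ, neither is 1
except RuntimeError as err:
    mismatch_raised = True
    print('RuntimeError:', err)
```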


######################################################################
# Here’s a small sample of the mathematical operations available:
#

r = (torch.rand(2, 2) - 0.5) * 2 # values between -1 and 1
print('A random matrix, r:')
print(r)

# Common mathematical operations are supported:
print('\nAbsolute value of r:')
print(torch.abs(r))

# ...as are trigonometric functions:
print('\nInverse sine of r:')
print(torch.asin(r))

# ...and linear algebra operations like determinant and singular value decomposition
print('\nDeterminant of r:')
print(torch.det(r))
print('\nSingular value decomposition of r:')
print(torch.svd(r)) # newer code may prefer torch.linalg.svd

# ...and statistical and aggregate operations:
print('\nStandard deviation and average of r:')
print(torch.std_mean(r)) # returns a (std, mean) tuple
print('\nMaximum value of r:')
print(torch.max(r))
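

######################################################################
# A short sketch checking two of these results, assuming a PyTorch
# release that provides ``torch.linalg``: the factors returned by
# ``torch.linalg.svd`` (which returns ``Vh``, the transpose of the ``V``
# that ``torch.svd`` returns) multiply back to ``r``, and
# ``torch.std_mean`` returns the standard deviation first:
#

```python
import torch

torch.manual_seed(1729)
r = (torch.rand(2, 2) - 0.5) * 2

# Singular value decomposition: r == U @ diag(S) @ Vh (up to rounding)
U, S, Vh = torch.linalg.svd(r)
reconstructed = U @ torch.diag(S) @ Vh
print(torch.allclose(reconstructed, r, atol=1e-5))   # True

# torch.std_mean returns (std, mean), in that order
std, mean = torch.std_mean(r)
print(torch.allclose(std, r.std()), torch.allclose(mean, r.mean()))
```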


##########################################################################
# There’s a good deal more to know about the power of PyTorch tensors,
# including how to set them up for parallel computations on GPU - we’ll be
# going into more depth in another video.
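#
# As a brief preview, a common pattern is to select a GPU when PyTorch can
# see one and otherwise fall back to the CPU, then move tensors with
# ``.to()``. This is a sketch; which device gets picked depends on your
# machine:
#

```python
import torch

# Use a CUDA GPU if one is visible to PyTorch, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

x = torch.rand(2, 2)        # tensors are created on the CPU by default
x = x.to(device)            # copy to the selected device (no copy if already there)
y = x * 2                   # the operation runs on the device holding its operands
print(y.device)
```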
133
#
134
# PyTorch Models
135
# --------------
136
#
137
# Follow along with the video beginning at `10:00 <https://www.youtube.com/watch?v=IC0_FRiX-sw&t=600s>`__.
138
#
139
# Let’s talk about how we can express models in PyTorch
140
#
141
142
import torch # for all things PyTorch
143
import torch.nn as nn # for torch.nn.Module, the parent object for PyTorch models
144
import torch.nn.functional as F # for the activation function


#########################################################################
# .. figure:: /_static/img/mnist.png
#    :alt: le-net-5 diagram
#
# *Figure: LeNet-5*
#
# Above is a diagram of LeNet-5, one of the earliest convolutional neural
# nets, and one of the drivers of the explosion in Deep Learning. It was
# built to read small images of handwritten numbers (the MNIST dataset),
# and correctly classify which digit was represented in the image.
#
# Here’s the abridged version of how it works:
#
# - Layer C1 is a convolutional layer, meaning that it scans the input
#   image for features it learned during training. It outputs a map of
#   where it saw each of its learned features in the image. This
#   “activation map” is downsampled in layer S2.
# - Layer C3 is another convolutional layer, this time scanning C1’s
#   activation map for *combinations* of features. It also puts out an
#   activation map describing the spatial locations of these feature
#   combinations, which is downsampled in layer S4.
# - Finally, the fully-connected layers at the end, F5, F6, and OUTPUT,
#   are a *classifier* that takes the final activation map, and
#   classifies it into one of ten bins representing the 10 digits.
#
# How do we express this simple neural network in code?
#

class LeNet(nn.Module):

    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel (black & white), 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels x 5x5 feature map after the second pooling
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the window is square, you can specify its size with a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
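

############################################################################
# The ``16 * 5 * 5`` input size of ``fc1`` comes from tracking the spatial
# size of the activation maps through the network. A plain-Python sketch of
# that bookkeeping, assuming a 32x32 input, no padding, stride 1 for the
# convolutions, and 2x2 pooling with stride 2 (the settings used above);
# ``conv2d_out`` is a helper name of our own:
#

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output size along one spatial axis of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32                                      # 32x32 input image
size = conv2d_out(size, kernel=5)              # conv1, 5x5 kernel  -> 28
size = conv2d_out(size, kernel=2, stride=2)    # 2x2 max pooling    -> 14
size = conv2d_out(size, kernel=5)              # conv2, 5x5 kernel  -> 10
size = conv2d_out(size, kernel=2, stride=2)    # 2x2 max pooling    -> 5

flat = 16 * size * size                        # conv2 produces 16 channels
print(size, flat)                              # 5 400
```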


############################################################################
# Looking over this code, you should be able to spot some structural
# similarities with the diagram above.
#
# This demonstrates the structure of a typical PyTorch model:
#
# - It inherits from ``torch.nn.Module`` - modules may be nested - in fact,
#   even the ``Conv2d`` and ``Linear`` layer classes inherit from
#   ``torch.nn.Module``.
# - A model will have an ``__init__()`` function, where it instantiates
#   its layers, and loads any data artifacts it might
#   need (e.g., an NLP model might load a vocabulary).
# - A model will have a ``forward()`` function. This is where the actual
#   computation happens: An input is passed through the network layers
#   and various functions to generate an output.
# - Other than that, you can build out your model class like any other
#   Python class, adding whatever properties and methods you need to
#   support your model’s computation.
#
# Let’s instantiate this object and run a sample input through it.
#

net = LeNet()
print(net) # what does the object tell us about itself?

input = torch.rand(1, 1, 32, 32) # stand-in for a 32x32 black & white image
print('\nImage batch shape:')
print(input.shape)

output = net(input) # we don't call forward() directly
print('\nRaw output:')
print(output)
print(output.shape)
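

############################################################################
# The nesting described above can be seen on a small scale. In this
# sketch, ``tiny`` is a hypothetical model built by nesting modules inside
# ``nn.Sequential``; ``named_children()`` lists the nested modules, and
# summing ``numel()`` over ``parameters()`` counts the learnable weights:
#

```python
import torch.nn as nn

# Hypothetical model: modules nested inside an nn.Sequential container
tiny = nn.Sequential(
    nn.Linear(4, 3),   # 4*3 weights + 3 biases = 15 parameters
    nn.ReLU(),         # no learnable parameters
    nn.Linear(3, 2),   # 3*2 weights + 2 biases = 8 parameters
)

# The direct children are themselves nn.Module instances
for name, child in tiny.named_children():
    print(name, child)

n_params = sum(p.numel() for p in tiny.parameters())
print(n_params)   # 23
```

# By our own count, the same expression applied to ``net`` above gives
# 61,706 (156 + 2,416 + 48,120 + 10,164 + 850 weights and biases for
# ``conv1``, ``conv2``, ``fc1``, ``fc2``, and ``fc3``).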


##########################################################################
# There are a few important things happening above:
#
# First, we instantiate the ``LeNet`` class, and we print the ``net``
# object. A subclass of ``torch.nn.Module`` will report the layers it has
# created and their shapes and parameters. This can provide a handy
# overview of a model if you want to get the gist of its processing.
#
# Below that, we create a dummy input representing a 32x32 image with 1
# color channel. Normally, you would load an image tile and convert it to
# a tensor of this shape.
#
# You may have noticed an extra dimension to our tensor - the *batch
# dimension.* PyTorch models assume they are working on *batches* of data
# - for example, a batch of 16 of our image tiles would have the shape
# ``(16, 1, 32, 32)``. Since we’re only using one image, we create a batch
# of 1 with shape ``(1, 1, 32, 32)``.
#
# We ask the model for an inference by calling it like a function:
# ``net(input)``. The output of this call represents the model’s
# confidence that the input represents a particular digit. (Since this
# instance of the model hasn’t learned anything yet, we shouldn’t expect
# to see any signal in the output.) Looking at the shape of ``output``, we
# can see that it also has a batch dimension, the size of which should
# always match the input batch dimension. If we had passed in an input
# batch of 16 instances, ``output`` would have a shape of ``(16, 10)``.
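#
# A sketch of adding and using the batch dimension, with a hypothetical
# stand-in classifier (``nn.Flatten`` followed by one ``nn.Linear``) in
# place of ``LeNet``: ``unsqueeze(0)`` turns a single image into a batch
# of one, and the output's first dimension follows the input's:
#

```python
import torch
import torch.nn as nn

# Hypothetical stand-in classifier: flatten each 1x32x32 image, map to 10 scores
model = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 10))

image = torch.rand(1, 32, 32)          # one image: (channels, height, width)
batch_of_one = image.unsqueeze(0)      # add a leading batch dimension
print(batch_of_one.shape)              # torch.Size([1, 1, 32, 32])

out1 = model(batch_of_one)
out16 = model(torch.rand(16, 1, 32, 32))
print(out1.shape, out16.shape)         # torch.Size([1, 10]) torch.Size([16, 10])
```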
#
# Datasets and Dataloaders
# ------------------------
#
# Follow along with the video beginning at `14:00 <https://www.youtube.com/watch?v=IC0_FRiX-sw&t=840s>`__.
#
# Below, we’ll demonstrate how to use one of the ready-to-download,
# open-access datasets from TorchVision, how to transform the images for
# consumption by your model, and how to use the DataLoader to feed batches
# of data to your model.
#
# The first thing we need to do is transform our incoming images into a
# PyTorch tensor.
#

#%matplotlib inline

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])


##########################################################################
# Here, we specify two transformations for our input:
#
# - ``transforms.ToTensor()`` converts images loaded by Pillow into
#   PyTorch tensors, reordering them to (channel, height, width) and
#   scaling 8-bit pixel values from [0, 255] to [0.0, 1.0].
# - ``transforms.Normalize()`` adjusts the values of each channel so
#   that their average is zero and their standard deviation is 1.0. Most
#   activation functions have their strongest gradients around x = 0, so
#   centering our data there can speed learning.
#
# The values passed to the transform are the means (first tuple) and the
# standard deviations (second tuple) of the RGB values of the images in
# the dataset. You can calculate these values yourself by running these
# few lines of code:
#
# .. code-block:: python
#
#    from torch.utils.data import ConcatDataset
#    transform = transforms.Compose([transforms.ToTensor()])
#    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
#                                            download=True, transform=transform)
#
#    # stack all train images together into a tensor of shape
#    # (50000, 3, 32, 32)
#    x = torch.stack([sample[0] for sample in ConcatDataset([trainset])])
#
#    # get the mean and standard deviation of each channel
#    mean = torch.mean(x, dim=(0,2,3)) # tensor([0.4914, 0.4822, 0.4465])
#    std = torch.std(x, dim=(0,2,3))   # tensor([0.2470, 0.2435, 0.2616])
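#
# To make the two transforms concrete, here is a sketch of the equivalent
# tensor arithmetic for a hypothetical random 8-bit RGB image, using the
# CIFAR-10 statistics above. It assumes the input is a ``uint8`` image in
# (height, width, channel) layout, as Pillow and NumPy store it; the
# inverse, ``x * std + mean``, is what you need to display a normalized
# image:
#

```python
import numpy as np
import torch

# Hypothetical 32x32 RGB image as Pillow/NumPy would hold it: uint8, HWC
rng = np.random.default_rng(0)
hwc = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)

# ToTensor() equivalent for uint8 input: HWC -> CHW, scale [0, 255] -> [0.0, 1.0]
chw = torch.from_numpy(hwc).permute(2, 0, 1).float() / 255.0

# Normalize(mean, std) equivalent: (x - mean) / std, channel by channel
mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
normalized = (chw - mean) / std

# Undo the normalization, e.g. before plotting
restored = normalized * std + mean
print(torch.allclose(restored, chw, atol=1e-5))   # True
```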
#
# There are many more transforms available, including cropping, centering,
# rotation, and reflection.
#
# Next, we’ll create an instance of the CIFAR10 dataset. This is a set of
# 32x32 color image tiles representing 10 classes of objects: 6 of animals
# (bird, cat, deer, dog, frog, horse) and 4 of vehicles (airplane,
# automobile, ship, truck):
#

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)


##########################################################################
# .. note::
#     When you run the cell above, it may take a little time for the
#     dataset to download.
#
# This is an example of creating a dataset object in PyTorch. Downloadable
# datasets (like CIFAR-10 above) are subclasses of
# ``torch.utils.data.Dataset``. ``Dataset`` classes in PyTorch include the
# downloadable datasets in TorchVision, Torchtext, and TorchAudio, as well
# as utility dataset classes such as ``torchvision.datasets.ImageFolder``,
# which will read a folder of labeled images. You can also create your own
# subclasses of ``Dataset``.
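#
# A minimal sketch of such a subclass: a map-style ``Dataset`` needs
# ``__len__`` and ``__getitem__``, and a ``DataLoader`` can batch it
# directly. ``SquaresDataset`` is a hypothetical toy example whose sample
# ``i`` is the pair ``(i, i**2)``:
#

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (tensor([i]), tensor([i**2]))."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples; samplers and DataLoader rely on this
        return self.n

    def __getitem__(self, idx):
        # Return one (input, label) pair for the given index
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx) ** 2])
        return x, y


ds = SquaresDataset(10)
print(len(ds), ds[3])

# The default collate function stacks four (1,)-shaped samples into (4, 1)
xb, yb = next(iter(DataLoader(ds, batch_size=4)))
print(xb.shape, yb.shape)   # torch.Size([4, 1]) torch.Size([4, 1])
```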
#
# When we instantiate our dataset, we need to tell it a few things:
#
# - The filesystem path to where we want the data to go.
# - Whether or not we are using this set for training; most datasets
#   will be split into training and test subsets.
# - Whether we would like to download the dataset if we haven’t already.
# - The transformations we want to apply to the data.
#
# Once your dataset is ready, you can give it to the ``DataLoader``:
#

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)


##########################################################################
# A ``Dataset`` subclass wraps access to the data, and is specialized to
# the type of data it’s serving. The ``DataLoader`` knows *nothing* about
# the data, but organizes the input tensors served by the ``Dataset`` into
# batches with the parameters you specify.
#
# In the example above, we’ve asked a ``DataLoader`` to give us batches of
# 4 images from ``trainset``, randomizing their order (``shuffle=True``),
# and we told it to spin up two workers to load data from disk.
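#
# A sketch of that batching on a hypothetical 10-sample ``TensorDataset``:
# with ``batch_size=4`` the final batch holds only the 2 leftover samples
# unless you pass ``drop_last=True``, and ``shuffle=True`` changes the
# order but still visits every sample once per pass:
#

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 hypothetical samples with 3 features each, and integer labels 0..9
features = torch.arange(30, dtype=torch.float32).view(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch_shapes, seen_labels = [], []
for x, y in loader:
    batch_shapes.append(tuple(x.shape))
    seen_labels.extend(y.tolist())

print(len(loader), batch_shapes)   # 3 [(4, 3), (4, 3), (2, 3)]
print(sorted(seen_labels))         # every sample appears exactly once

# drop_last=True discards the incomplete final batch
n_full_batches = len(DataLoader(dataset, batch_size=4, drop_last=True))
print(n_full_batches)              # 2
```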
#
# It’s good practice to visualize the batches your ``DataLoader`` serves:
#

import matplotlib.pyplot as plt
import numpy as np

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def imshow(img):
    # unnormalize: invert the per-channel Normalize() above (x * std + mean)
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
    img = img * std + mean
    npimg = img.numpy().clip(0, 1)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))


########################################################################
# Running the above cell should show you a strip of four images, and the
# correct label for each.
#
# Training Your PyTorch Model
# ---------------------------
#
# Follow along with the video beginning at `17:10 <https://www.youtube.com/watch?v=IC0_FRiX-sw&t=1030s>`__.
#
# Let’s put all the pieces together, and train a model:
#

#%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np


#########################################################################
# First, we’ll need training and test datasets. If you haven’t already,
# run the cell below to make sure the dataset is downloaded. (It may take
# a minute.) Note that this cell uses a simpler normalization than the one
# above: a mean and standard deviation of 0.5 for every channel, which
# maps pixel values from [0, 1] to [-1, 1].
#

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


######################################################################
# We’ll run our check on the output from ``DataLoader``:
#

import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))


##########################################################################
# This is the model we’ll train. If it looks familiar, that’s because it’s
# a variant of LeNet - discussed earlier in this video - adapted for
# 3-color images.
#

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()


######################################################################
# The last ingredients we need are a loss function and an optimizer:
#

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)


##########################################################################
# The loss function, as discussed earlier in this video, is a measure of
# how far from our ideal output the model’s prediction was. Cross-entropy
# loss is a typical loss function for classification models like ours.
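#
# A sketch of what ``nn.CrossEntropyLoss`` computes for one sample with
# made-up scores: the negative log of the softmax probability assigned to
# the correct class. With no information (equal scores across 10 classes)
# the loss is ln(10), about 2.303, which is a useful sanity check for the
# first losses an untrained 10-class model reports:
#

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()

logits = torch.tensor([[2.0, 1.0, 0.1]])   # made-up raw scores: 1 sample, 3 classes
target = torch.tensor([0])                 # index of the correct class

# Negative log-softmax of the correct class (batch of one, so no averaging effect)
builtin = criterion(logits, target).item()
manual = -F.log_softmax(logits, dim=1)[0, 0].item()
print(builtin, manual)

# Equal scores over 10 classes: every class gets probability 1/10
uniform = criterion(torch.zeros(1, 10), torch.tensor([3])).item()
print(uniform, math.log(10))               # both about 2.3026
```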
#
# The **optimizer** is what drives the learning. Here we have created an
# optimizer that implements *stochastic gradient descent,* one of the more
# straightforward optimization algorithms. Besides parameters of the
# algorithm, like the learning rate (``lr``) and momentum, we also pass in
# ``net.parameters()``, which is a collection of all the learning weights
# in the model - which is what the optimizer adjusts.
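#
# A sketch of the update ``optim.SGD`` applies with momentum (assuming its
# default ``dampening=0`` and ``nesterov=False``), checked against a
# hand-written version on a hypothetical one-weight problem that minimizes
# ``w**2``. The larger ``lr`` here is only to make the steps visible:
#

```python
import torch
import torch.optim as optim

lr, momentum = 0.1, 0.9

# A single hypothetical scalar weight; the loss w**2 has gradient 2*w
w = torch.tensor(1.0, requires_grad=True)
optimizer = optim.SGD([w], lr=lr, momentum=momentum)

w_manual, velocity = 1.0, 0.0
for step in range(2):
    # PyTorch's update
    optimizer.zero_grad()
    loss = w ** 2
    loss.backward()
    optimizer.step()

    # The same update by hand:
    # velocity <- momentum * velocity + grad;  w <- w - lr * velocity
    grad = 2 * w_manual
    velocity = momentum * velocity + grad
    w_manual -= lr * velocity

print(w.item(), w_manual)   # both close to 0.46 after two steps
```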
#
# Finally, all of this is assembled into the training loop. Go ahead and
# run this cell, as it will likely take a few minutes to execute:
#

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')


########################################################################
# Here, we are doing only **2 training epochs** (line 1) - that is, two
# passes over the training dataset. Each pass has an inner loop that
# **iterates over the training data** (line 4), serving batches of
# transformed input images and their correct labels.
#
# **Zeroing the gradients** (line 9) is an important step. PyTorch
# accumulates gradients: each ``backward()`` call adds to the gradients
# already stored for every parameter instead of replacing them. If we do
# not reset them for every batch, gradients from earlier batches keep
# piling up, which will provide incorrect gradient values, making
# learning impossible.
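#
# A minimal sketch of this accumulation on a single hypothetical weight
# ``w``: calling ``backward()`` a second time without a reset doubles the
# stored gradient, and resetting restores the per-step value:
#

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

(3 * w).backward()          # d(3w)/dw = 3
first = w.grad.item()

(3 * w).backward()          # no reset: the new gradient is added to w.grad
accumulated = w.grad.item()

w.grad.zero_()              # reset before the next backward pass
(3 * w).backward()
after_reset = w.grad.item()

print(first, accumulated, after_reset)   # 3.0 6.0 3.0
```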
#
# In line 12, we **ask the model for its predictions** on this batch. In
# the following line (13), we compute the loss - a measure of the difference
# between ``outputs`` (the model prediction) and ``labels`` (the correct output).
#
# In line 14, we do the ``backward()`` pass, and calculate the gradients
# that will direct the learning.
#
# In line 15, the optimizer performs one learning step - it uses the
# gradients from the ``backward()`` call to nudge the learning weights in
# the direction it thinks will reduce the loss.
#
# The remainder of the loop does some light reporting on the epoch number,
# how many mini-batches have been completed, and the average loss over the
# last 2000 mini-batches.
#
# **When you run the cell above,** you should see something like this:
#
# .. code-block:: sh
#
#    [1,  2000] loss: 2.235
#    [1,  4000] loss: 1.940
#    [1,  6000] loss: 1.713
#    [1,  8000] loss: 1.573
#    [1, 10000] loss: 1.507
#    [1, 12000] loss: 1.442
#    [2,  2000] loss: 1.378
#    [2,  4000] loss: 1.364
#    [2,  6000] loss: 1.349
#    [2,  8000] loss: 1.319
#    [2, 10000] loss: 1.284
#    [2, 12000] loss: 1.267
#    Finished Training
#
# Note that the loss is monotonically decreasing, indicating that our
# model is continuing to improve its performance on the training dataset.
#
# As a final step, we should check that the model is actually doing
# *general* learning, and not simply “memorizing” the dataset. This is
# called **overfitting,** and usually indicates that the dataset is too
# small (not enough examples for general learning), or that the model has
# more learning parameters than it needs to correctly model the dataset.
#
# This is the reason datasets are split into training and test subsets -
# to test the generality of the model, we ask it to make predictions on
# data it hasn’t trained on:
#

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)  # index of the highest score = predicted class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))


#########################################################################
# If you followed along, you should see that the model is roughly 50%
# accurate at this point. That’s not exactly state-of-the-art, but it’s
# far better than the 10% accuracy we’d expect from a random output. This
# demonstrates that some general learning did happen in the model.
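#
# To check for the overfitting described above, you can compute the same
# metric on both splits and compare them; a training accuracy far above
# the test accuracy suggests memorization. A sketch of a reusable helper
# (``accuracy`` is a hypothetical name), checked here on hand-made scores
# passed straight through ``nn.Identity``:
#

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def accuracy(model, loader):
    """Fraction of samples whose highest-scoring class equals the label."""
    model.eval()                        # disable training-only behavior (e.g. dropout)
    correct = total = 0
    with torch.no_grad():               # gradients are not needed for evaluation
        for inputs, labels in loader:
            predicted = model(inputs).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total


# Hand-made "logits": predicted classes are 0, 1, 0, 1; labels are 0, 1, 1, 1
logits = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
labels = torch.tensor([0, 1, 1, 1])
loader = DataLoader(TensorDataset(logits, labels), batch_size=2)

acc = accuracy(nn.Identity(), loader)
print(acc)   # 0.75
```

# With the objects from this tutorial you could compare
# ``accuracy(net, trainloader)`` with ``accuracy(net, testloader)``; call
# ``net.train()`` before resuming training, since the helper switches the
# model to evaluation mode.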
#