# -*- coding: utf-8 -*-
"""
What is `torch.nn` *really*?
============================

**Authors:** Jeremy Howard, `fast.ai <https://www.fast.ai>`_. Thanks to Rachel Thomas and Francisco Ingham.
"""

###############################################################################
# We recommend running this tutorial as a notebook, not a script. To download the notebook (``.ipynb``) file,
# click the link at the top of the page.
#
# PyTorch provides the elegantly designed modules and classes `torch.nn <https://pytorch.org/docs/stable/nn.html>`_ ,
# `torch.optim <https://pytorch.org/docs/stable/optim.html>`_ ,
# `Dataset <https://pytorch.org/docs/stable/data.html?highlight=dataset#torch.utils.data.Dataset>`_ ,
# and `DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`_
# to help you create and train neural networks.
# In order to fully utilize their power and customize
# them for your problem, you need to really understand exactly what they're
# doing. To develop this understanding, we will first train a basic neural net
# on the MNIST data set without using any features from these classes; we will
# initially only use the most basic PyTorch tensor functionality. Then, we will
# incrementally add one feature from ``torch.nn``, ``torch.optim``, ``Dataset``, or
# ``DataLoader`` at a time, showing exactly what each piece does, and how it
# works to make the code either more concise, or more flexible.
#
# **This tutorial assumes you already have PyTorch installed, and are familiar
# with the basics of tensor operations.** (If you're familiar with Numpy array
# operations, you'll find the PyTorch tensor operations used here nearly identical.)
#
# MNIST data setup
# ----------------
#
# We will use the classic `MNIST <http://deeplearning.net/data/mnist/>`_ dataset,
# which consists of black-and-white images of hand-drawn digits (between 0 and 9).
#
# We will use `pathlib <https://docs.python.org/3/library/pathlib.html>`_
# for dealing with paths (part of the Python 3 standard library), and will
# download the dataset using
# `requests <http://docs.python-requests.org/en/master/>`_. We will only
# import modules when we use them, so you can see exactly what's being
# used at each point.

from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "https://github.com/pytorch/tutorials/raw/main/_static/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
    content = requests.get(URL + FILENAME).content
    (PATH / FILENAME).open("wb").write(content)

###############################################################################
# This dataset is in numpy array format, and has been stored using pickle,
# a python-specific format for serializing data.

import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

###############################################################################
# Each image is 28 x 28, and is being stored as a flattened row of length
# 784 (=28x28). Let's take a look at one; we need to reshape it to 2d
# first.

from matplotlib import pyplot
import numpy as np

pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
# ``pyplot.show()`` only if not on Colab
try:
    import google.colab
except ImportError:
    pyplot.show()
print(x_train.shape)

###############################################################################
# PyTorch uses ``torch.tensor``, rather than numpy arrays, so we need to
# convert our data.

import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

###############################################################################
# Neural net from scratch (without ``torch.nn``)
# -----------------------------------------------
#
# Let's first create a model using nothing but PyTorch tensor operations. We're assuming
# you're already familiar with the basics of neural networks. (If you're not, you can
# learn them at `course.fast.ai <https://course.fast.ai>`_.)
#
# PyTorch provides methods to create random or zero-filled tensors, which we will
# use to create our weights and bias for a simple linear model. These are just regular
# tensors, with one very special addition: we tell PyTorch that they require a
# gradient. This causes PyTorch to record all of the operations done on the tensor,
# so that it can calculate the gradient during back-propagation *automatically*!
#
# For the weights, we set ``requires_grad`` **after** the initialization, since we
# don't want that step included in the gradient. (Note that a trailing ``_`` in
# PyTorch signifies that the operation is performed in-place.)
#
# .. note:: We are initializing the weights here with
#    `Xavier initialisation <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
#    (by multiplying with ``1/sqrt(n)``).

import math

weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)

###############################################################################
# Thanks to PyTorch's ability to calculate gradients automatically, we can
# use any standard Python function (or callable object) as a model! So
# let's just write a plain matrix multiplication and broadcasted addition
# to create a simple linear model. We also need an activation function, so
# we'll write ``log_softmax`` and use it. Remember: although PyTorch
# provides lots of prewritten loss functions, activation functions, and
# so forth, you can easily write your own using plain python. PyTorch will
# even create fast GPU or vectorized CPU code for your function
# automatically.

def log_softmax(x):
    return x - x.exp().sum(-1).log().unsqueeze(-1)

def model(xb):
    return log_softmax(xb @ weights + bias)
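
###############################################################################
# (Optional) If you'd like to convince yourself that the hand-written
# ``log_softmax`` matches PyTorch's built-in version, a quick check along
# these lines should work (``xb_check``, ``ours`` and ``theirs`` are just
# illustrative names):
#
# .. code-block:: python
#
#    xb_check = x_train[0:4]  # a few sample rows
#    ours = log_softmax(xb_check @ weights + bias)
#    theirs = torch.log_softmax(xb_check @ weights + bias, dim=-1)
#    print(torch.allclose(ours, theirs, atol=1e-6))  # expected: True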

######################################################################################
# In the ``model`` function above, the ``@`` stands for the matrix multiplication
# operation. We will call our function on one batch of data (in this case, 64
# images). This is one *forward pass*. Note that our predictions won't be any
# better than random at this stage, since we start with random weights.

bs = 64  # batch size

xb = x_train[0:bs]  # a mini-batch from x
preds = model(xb)  # predictions
preds[0], preds.shape
print(preds[0], preds.shape)

###############################################################################
# As you see, the ``preds`` tensor contains not only the tensor values, but also a
# gradient function. We'll use this later to do backprop.
#
# Let's implement negative log-likelihood to use as the loss function
# (again, we can just use standard Python):


def nll(input, target):
    return -input[range(target.shape[0]), target].mean()

loss_func = nll
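
###############################################################################
# The indexing inside ``nll`` may look cryptic: ``input[range(target.shape[0]), target]``
# picks, for each row, the predicted log-probability of that row's target class.
# A tiny illustrative example (the ``toy_*`` names are made up for this sketch):
#
# .. code-block:: python
#
#    toy_logprobs = torch.tensor([[-0.1, -2.0, -3.0],
#                                 [-4.0, -0.2, -1.5]])
#    toy_targets = torch.tensor([0, 1])
#    # row 0 -> class 0, row 1 -> class 1
#    print(toy_logprobs[range(toy_targets.shape[0]), toy_targets])  # tensor([-0.1000, -0.2000])
#    print(-toy_logprobs[range(toy_targets.shape[0]), toy_targets].mean())  # tensor(0.1500)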

###############################################################################
# Let's check our loss with our random model, so we can see if we improve
# after a backprop pass later.

yb = y_train[0:bs]
print(loss_func(preds, yb))


###############################################################################
# Let's also implement a function to calculate the accuracy of our model.
# For each prediction, if the index with the largest value matches the
# target value, then the prediction was correct.

def accuracy(out, yb):
    preds = torch.argmax(out, dim=1)
    return (preds == yb).float().mean()

###############################################################################
# Let's check the accuracy of our random model, so we can see if our
# accuracy improves as our loss improves.

print(accuracy(preds, yb))

###############################################################################
# We can now run a training loop. For each iteration, we will:
#
# - select a mini-batch of data (of size ``bs``)
# - use the model to make predictions
# - calculate the loss
# - ``loss.backward()`` updates the gradients of the model, in this case, ``weights``
#   and ``bias``.
#
# We now use these gradients to update the weights and bias. We do this
# within the ``torch.no_grad()`` context manager, because we do not want these
# actions to be recorded for our next calculation of the gradient. You can read
# more about how PyTorch's Autograd records operations
# `here <https://pytorch.org/docs/stable/notes/autograd.html>`_.
#
# We then set the
# gradients to zero, so that we are ready for the next loop.
# Otherwise, our gradients would record a running tally of all the operations
# that had happened (i.e. ``loss.backward()`` *adds* the gradients to whatever is
# already stored, rather than replacing them).
#
# .. tip:: You can use the standard python debugger to step through PyTorch
#    code, allowing you to check the various variable values at each step.
#    Uncomment ``set_trace()`` below to try it out.
#

from IPython.core.debugger import set_trace

lr = 0.5  # learning rate
epochs = 2  # how many epochs to train for

for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        # set_trace()
        start_i = i * bs
        end_i = start_i + bs
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        with torch.no_grad():
            weights -= weights.grad * lr
            bias -= bias.grad * lr
            weights.grad.zero_()
            bias.grad.zero_()

###############################################################################
# That's it: we've created and trained a minimal neural network (in this case, a
# logistic regression, since we have no hidden layers) entirely from scratch!
#
# Let's check the loss and accuracy and compare those to what we got
# earlier. We expect that the loss will have decreased and accuracy to
# have increased, and they have.

print(loss_func(model(xb), yb), accuracy(model(xb), yb))

###############################################################################
# Using ``torch.nn.functional``
# ------------------------------
#
# We will now refactor our code, so that it does the same thing as before, only
# we'll start taking advantage of PyTorch's ``nn`` classes to make it more concise
# and flexible. At each step from here, we should be making our code one or more
# of: shorter, more understandable, and/or more flexible.
#
# The first and easiest step is to make our code shorter by replacing our
# hand-written activation and loss functions with those from ``torch.nn.functional``
# (which is generally imported into the namespace ``F`` by convention). This module
# contains all the functions in the ``torch.nn`` library (whereas other parts of the
# library contain classes). As well as a wide range of loss and activation
# functions, you'll also find here some convenient functions for creating neural
# nets, such as pooling functions. (There are also functions for doing convolutions,
# linear layers, etc, but as we'll see, these are usually better handled using
# other parts of the library.)
#
# If you're using negative log likelihood loss and log softmax activation,
# then PyTorch provides a single function ``F.cross_entropy`` that combines
# the two. So we can even remove the activation function from our model.

import torch.nn.functional as F

loss_func = F.cross_entropy

def model(xb):
    return xb @ weights + bias

###############################################################################
# Note that we no longer call ``log_softmax`` in the ``model`` function. Let's
# confirm that our loss and accuracy are the same as before:

print(loss_func(model(xb), yb), accuracy(model(xb), yb))

###############################################################################
# Refactor using ``nn.Module``
# -----------------------------
# Next up, we'll use ``nn.Module`` and ``nn.Parameter``, for a clearer and more
# concise training loop. We subclass ``nn.Module`` (which itself is a class and
# able to keep track of state). In this case, we want to create a class that
# holds our weights, bias, and method for the forward step. ``nn.Module`` has a
# number of attributes and methods (such as ``.parameters()`` and ``.zero_grad()``)
# which we will be using.
#
# .. note:: ``nn.Module`` (uppercase ``M``) is a PyTorch-specific concept, and is a
#    class we'll be using a lot. ``nn.Module`` is not to be confused with the Python
#    concept of a (lowercase ``m``) `module <https://docs.python.org/3/tutorial/modules.html>`_,
#    which is a file of Python code that can be imported.

from torch import nn

class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))
        self.bias = nn.Parameter(torch.zeros(10))

    def forward(self, xb):
        return xb @ self.weights + self.bias

###############################################################################
# Since we're now using an object instead of just using a function, we
# first have to instantiate our model:

model = Mnist_Logistic()

###############################################################################
# Now we can calculate the loss in the same way as before. Note that
# ``nn.Module`` objects are used as if they are functions (i.e. they are
# *callable*), but behind the scenes PyTorch will call our ``forward``
# method automatically.

print(loss_func(model(xb), yb))
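
###############################################################################
# A handy way to see what ``nn.Module`` is tracking for us is to list the
# registered parameters. An optional check along these lines should work
# (expected output shown as comments):
#
# .. code-block:: python
#
#    for name, p in model.named_parameters():
#        print(name, p.shape, p.requires_grad)
#    # weights torch.Size([784, 10]) True
#    # bias torch.Size([10]) True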

###############################################################################
# Previously for our training loop we had to update the values for each parameter
# by name, and manually zero out the grads for each parameter separately, like this:
#
# .. code-block:: python
#
#    with torch.no_grad():
#        weights -= weights.grad * lr
#        bias -= bias.grad * lr
#        weights.grad.zero_()
#        bias.grad.zero_()
#
#
# Now we can take advantage of ``model.parameters()`` and ``model.zero_grad()`` (which
# are both defined by PyTorch for ``nn.Module``) to make those steps more concise
# and less prone to the error of forgetting some of our parameters, particularly
# if we had a more complicated model:
#
# .. code-block:: python
#
#    with torch.no_grad():
#        for p in model.parameters(): p -= p.grad * lr
#        model.zero_grad()
#
#
# We'll wrap our little training loop in a ``fit`` function so we can run it
# again later.

def fit():
    for epoch in range(epochs):
        for i in range((n - 1) // bs + 1):
            start_i = i * bs
            end_i = start_i + bs
            xb = x_train[start_i:end_i]
            yb = y_train[start_i:end_i]
            pred = model(xb)
            loss = loss_func(pred, yb)

            loss.backward()
            with torch.no_grad():
                for p in model.parameters():
                    p -= p.grad * lr
                model.zero_grad()

fit()

###############################################################################
# Let's double-check that our loss has gone down:

print(loss_func(model(xb), yb))

###############################################################################
# Refactor using ``nn.Linear``
# ----------------------------
#
# We continue to refactor our code. Instead of manually defining and
# initializing ``self.weights`` and ``self.bias``, and calculating ``xb @
# self.weights + self.bias``, we will instead use the PyTorch class
# `nn.Linear <https://pytorch.org/docs/stable/nn.html#linear-layers>`_ for a
# linear layer, which does all that for us. PyTorch has many types of
# predefined layers that can greatly simplify our code, and often make it
# faster too.

class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(784, 10)

    def forward(self, xb):
        return self.lin(xb)

###############################################################################
# We instantiate our model and calculate the loss in the same way as before:

model = Mnist_Logistic()
print(loss_func(model(xb), yb))

###############################################################################
# We are still able to use our same ``fit`` method as before.

fit()

print(loss_func(model(xb), yb))

###############################################################################
# Refactor using ``torch.optim``
# ------------------------------
#
# PyTorch also has a package with various optimization algorithms, ``torch.optim``.
# We can use the ``step`` method from our optimizer to take an optimization step,
# instead of manually updating each parameter.
#
# This will let us replace our previous manually coded optimization step:
#
# .. code-block:: python
#
#    with torch.no_grad():
#        for p in model.parameters(): p -= p.grad * lr
#        model.zero_grad()
#
# and instead use just:
#
# .. code-block:: python
#
#    opt.step()
#    opt.zero_grad()
#
# (``optim.zero_grad()`` resets the gradient to 0; we need to call it before
# computing the gradient for the next minibatch.)

from torch import optim

###############################################################################
# We'll define a little function to create our model and optimizer so we
# can reuse it in the future.

def get_model():
    model = Mnist_Logistic()
    return model, optim.SGD(model.parameters(), lr=lr)

model, opt = get_model()
print(loss_func(model(xb), yb))

for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        start_i = i * bs
        end_i = start_i + bs
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

print(loss_func(model(xb), yb))
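
###############################################################################
# Because the optimizer is now a separate object, swapping in a different
# algorithm is a one-line change. For example (an optional variation, not used
# in the rest of this tutorial; ``get_model_adam`` is just an illustrative name),
# something like this should work:
#
# .. code-block:: python
#
#    def get_model_adam():
#        model = Mnist_Logistic()
#        return model, optim.Adam(model.parameters(), lr=1e-3)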

###############################################################################
# Refactor using Dataset
# ------------------------------
#
# PyTorch has an abstract Dataset class. A Dataset can be anything that has
# a ``__len__`` function (called by Python's standard ``len`` function) and
# a ``__getitem__`` function as a way of indexing into it.
# `This tutorial <https://pytorch.org/tutorials/beginner/data_loading_tutorial.html>`_
# walks through a nice example of creating a custom ``FacialLandmarkDataset`` class
# as a subclass of ``Dataset``.
#
# PyTorch's `TensorDataset <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#TensorDataset>`_
# is a Dataset wrapping tensors. By defining a length and way of indexing,
# this also gives us a way to iterate, index, and slice along the first
# dimension of a tensor. This will make it easier to access both the
# independent and dependent variables in the same line as we train.

from torch.utils.data import TensorDataset
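
###############################################################################
# To make the ``__len__``/``__getitem__`` contract concrete, here is a minimal
# sketch of what a hand-rolled equivalent of ``TensorDataset`` for our two
# tensors could look like (``MnistDataset`` is just an illustrative name; below
# we use the built-in ``TensorDataset`` instead):
#
# .. code-block:: python
#
#    from torch.utils.data import Dataset
#
#    class MnistDataset(Dataset):
#        def __init__(self, x, y):
#            self.x, self.y = x, y
#
#        def __len__(self):
#            return len(self.x)
#
#        def __getitem__(self, i):
#            return self.x[i], self.y[i]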

###############################################################################
# Both ``x_train`` and ``y_train`` can be combined in a single ``TensorDataset``,
# which will be easier to iterate over and slice.

train_ds = TensorDataset(x_train, y_train)

###############################################################################
# Previously, we had to iterate through minibatches of ``x`` and ``y`` values separately:
#
# .. code-block:: python
#
#    xb = x_train[start_i:end_i]
#    yb = y_train[start_i:end_i]
#
#
# Now, we can do these two steps together:
#
# .. code-block:: python
#
#    xb, yb = train_ds[i * bs: i * bs + bs]
#

model, opt = get_model()

for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        xb, yb = train_ds[i * bs: i * bs + bs]
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

print(loss_func(model(xb), yb))

###############################################################################
# Refactor using ``DataLoader``
# ------------------------------
#
# PyTorch's ``DataLoader`` is responsible for managing batches. You can
# create a ``DataLoader`` from any ``Dataset``. ``DataLoader`` makes it easier
# to iterate over batches. Rather than having to use ``train_ds[i*bs : i*bs+bs]``,
# the ``DataLoader`` gives us each minibatch automatically.

from torch.utils.data import DataLoader

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs)

###############################################################################
# Previously, our loop iterated over batches ``(xb, yb)`` like this:
#
# .. code-block:: python
#
#    for i in range((n - 1) // bs + 1):
#        xb, yb = train_ds[i * bs: i * bs + bs]
#        pred = model(xb)
#
# Now, our loop is much cleaner, as ``(xb, yb)`` are loaded automatically from the data loader:
#
# .. code-block:: python
#
#    for xb, yb in train_dl:
#        pred = model(xb)

model, opt = get_model()

for epoch in range(epochs):
    for xb, yb in train_dl:
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

print(loss_func(model(xb), yb))

###############################################################################
# Thanks to PyTorch's ``nn.Module``, ``nn.Parameter``, ``Dataset``, and ``DataLoader``,
# our training loop is now dramatically smaller and easier to understand. Let's
# now try to add the basic features necessary to create effective models in practice.
#
# Add validation
# -----------------------
#
# In section 1, we were just trying to get a reasonable training loop set up for
# use on our training data. In reality, you **always** should also have
# a `validation set <https://www.fast.ai/2017/11/13/validation-sets/>`_, in order
# to identify if you are overfitting.
#
# Shuffling the training data is
# `important <https://www.quora.com/Does-the-order-of-training-data-matter-when-training-neural-networks>`_
# to prevent correlation between batches and overfitting. On the other hand, the
# validation loss will be identical whether we shuffle the validation set or not.
# Since shuffling takes extra time, it makes no sense to shuffle the validation data.
#
# We'll use a batch size for the validation set that is twice as large as
# that for the training set. This is because the validation set does not
# need backpropagation and thus takes less memory (it doesn't need to
# store the gradients). We take advantage of this to use a larger batch
# size and compute the loss more quickly.

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)

###############################################################################
# We will calculate and print the validation loss at the end of each epoch.
#
# (Note that we always call ``model.train()`` before training, and ``model.eval()``
# before inference, because these are used by layers such as ``nn.BatchNorm2d``
# and ``nn.Dropout`` to ensure appropriate behavior for these different phases.)

model, opt = get_model()

for epoch in range(epochs):
    model.train()
    for xb, yb in train_dl:
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

    model.eval()
    with torch.no_grad():
        valid_loss = sum(loss_func(model(xb), yb) for xb, yb in valid_dl)

    print(epoch, valid_loss / len(valid_dl))

###############################################################################
# Create fit() and get_data()
# ----------------------------------
#
# We'll now do a little refactoring of our own. Since we go through a similar
# process twice of calculating the loss for both the training set and the
# validation set, let's make that into its own function, ``loss_batch``, which
# computes the loss for one batch.
#
# We pass an optimizer in for the training set, and use it to perform
# backprop. For the validation set, we don't pass an optimizer, so the
# method doesn't perform backprop.


def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)

###############################################################################
# ``fit`` runs the necessary operations to train our model and compute the
# training and validation losses for each epoch.

import numpy as np

def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    for epoch in range(epochs):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)

        print(epoch, val_loss)

###############################################################################
# ``get_data`` returns dataloaders for the training and validation sets.


def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )

###############################################################################
# Now, our whole process of obtaining the data loaders and fitting the
# model can be run in 3 lines of code:

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(epochs, model, loss_func, opt, train_dl, valid_dl)

###############################################################################
# You can use these basic 3 lines of code to train a wide variety of models.
# Let's see if we can use them to train a convolutional neural network (CNN)!
#
# Switch to CNN
# -------------
#
# We are now going to build our neural network with three convolutional layers.
# Because none of the functions in the previous section assume anything about
# the model form, we'll be able to use them to train a CNN without any modification.
#
# We will use PyTorch's predefined
# `Conv2d <https://pytorch.org/docs/stable/nn.html#torch.nn.Conv2d>`_ class
# as our convolutional layer. We define a CNN with 3 convolutional layers.
# Each convolution is followed by a ReLU. At the end, we perform an
# average pooling. (Note that ``view`` is PyTorch's version of Numpy's
# ``reshape``.)

class Mnist_CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)

    def forward(self, xb):
        xb = xb.view(-1, 1, 28, 28)
        xb = F.relu(self.conv1(xb))
        xb = F.relu(self.conv2(xb))
        xb = F.relu(self.conv3(xb))
        xb = F.avg_pool2d(xb, 4)
        return xb.view(-1, xb.size(1))

lr = 0.1

###############################################################################
# `Momentum <https://cs231n.github.io/neural-networks-3/#sgd>`_ is a variation on
# stochastic gradient descent that takes previous updates into account as well
# and generally leads to faster training.
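#
# Roughly speaking, with momentum coefficient ``mu`` each parameter update keeps a
# running "velocity" of past gradients and steps along that smoothed direction
# (this is only a sketch of the idea; see the ``torch.optim.SGD`` documentation
# for the exact formula PyTorch uses):
#
# .. code-block:: python
#
#    buf = mu * buf + p.grad   # accumulate past gradients
#    p = p - lr * buf          # step along the smoothed direction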

model = Mnist_CNN()
opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

fit(epochs, model, loss_func, opt, train_dl, valid_dl)

###############################################################################
# Using ``nn.Sequential``
# ------------------------
#
# ``torch.nn`` has another handy class we can use to simplify our code:
# `Sequential <https://pytorch.org/docs/stable/nn.html#torch.nn.Sequential>`_ .
# A ``Sequential`` object runs each of the modules contained within it, in a
# sequential manner. This is a simpler way of writing our neural network.
#
# To take advantage of this, we need to be able to easily define a
# **custom layer** from a given function. For instance, PyTorch doesn't
# have a ``view`` layer, and we need to create one for our network. ``Lambda``
# will create a layer that we can then use when defining a network with
# ``Sequential``.

class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)


def preprocess(x):
    return x.view(-1, 1, 28, 28)

###############################################################################
# The model created with ``Sequential`` is simple:

model = nn.Sequential(
    Lambda(preprocess),
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AvgPool2d(4),
    Lambda(lambda x: x.view(x.size(0), -1)),
)

opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

fit(epochs, model, loss_func, opt, train_dl, valid_dl)

###############################################################################
# Wrapping ``DataLoader``
# -----------------------------
#
# Our CNN is fairly concise, but it only works with MNIST, because:
#
# - It assumes the input is a 28\*28 long vector
# - It assumes that the final CNN grid size is 4\*4 (since that's the average
#   pooling kernel size we used)
#
# Let's get rid of these two assumptions, so our model works with any 2d
# single channel image. First, we can remove the initial Lambda layer by
# moving the data preprocessing into a generator:

def preprocess(x, y):
    return x.view(-1, 1, 28, 28), y


class WrappedDataLoader:
    def __init__(self, dl, func):
        self.dl = dl
        self.func = func

    def __len__(self):
        return len(self.dl)

    def __iter__(self):
        for b in self.dl:
            yield (self.func(*b))

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
train_dl = WrappedDataLoader(train_dl, preprocess)
valid_dl = WrappedDataLoader(valid_dl, preprocess)

###############################################################################
# Next, we can replace ``nn.AvgPool2d`` with ``nn.AdaptiveAvgPool2d``, which
# allows us to define the size of the *output* tensor we want, rather than
# the *input* tensor we have. As a result, our model will work with any
# size input.

model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    Lambda(lambda x: x.view(x.size(0), -1)),
)

opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

###############################################################################
# Let's try it out:

fit(epochs, model, loss_func, opt, train_dl, valid_dl)

###############################################################################
# Using your GPU
# ---------------
#
# If you're lucky enough to have access to a CUDA-capable GPU (you can
# rent one for about $0.50/hour from most cloud providers) you can
# use it to speed up your code. First check that your GPU is working in
# PyTorch:

print(torch.cuda.is_available())

###############################################################################
# And then create a device object for it:

dev = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")

###############################################################################
# Let's update ``preprocess`` to move batches to the GPU:


def preprocess(x, y):
    return x.view(-1, 1, 28, 28).to(dev), y.to(dev)


train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
train_dl = WrappedDataLoader(train_dl, preprocess)
valid_dl = WrappedDataLoader(valid_dl, preprocess)

###############################################################################
# Finally, we can move our model to the GPU.

model.to(dev)
opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

###############################################################################
# You should find it runs faster now:

fit(epochs, model, loss_func, opt, train_dl, valid_dl)

###############################################################################
# Closing thoughts
# -----------------
#
# We now have a general data pipeline and training loop which you can use for
# training many types of models using PyTorch. To see how simple training a model
# can now be, take a look at the `mnist_sample notebook <https://github.com/fastai/fastai_dev/blob/master/dev_nb/mnist_sample.ipynb>`__.
#
# Of course, there are many things you'll want to add, such as data augmentation,
# hyperparameter tuning, monitoring training, transfer learning, and so forth.
# These features are available in the fastai library, which has been developed
# using the same design approach shown in this tutorial, providing a natural
# next step for practitioners looking to take their models further.
#
# We promised at the start of this tutorial we'd explain through example each of
# ``torch.nn``, ``torch.optim``, ``Dataset``, and ``DataLoader``. So let's summarize
# what we've seen:
#
# - ``torch.nn``:
#
#   + ``Module``: creates a callable which behaves like a function, but can also
#     contain state (such as neural net layer weights). It knows what ``Parameter`` (s) it
#     contains and can zero all their gradients, loop through them for weight updates, etc.
#   + ``Parameter``: a wrapper for a tensor that tells a ``Module`` that it has weights
#     that need updating during backprop. Only tensors with the ``requires_grad`` attribute set are updated.
#   + ``functional``: a module (usually imported into the ``F`` namespace by convention)
#     which contains activation functions, loss functions, etc, as well as non-stateful
#     versions of layers such as convolutional and linear layers.
# - ``torch.optim``: contains optimizers such as ``SGD``, which update the weights
#   of ``Parameter`` during the backward step.
# - ``Dataset``: an abstract interface of objects with a ``__len__`` and a ``__getitem__``,
#   including classes provided with PyTorch such as ``TensorDataset``.
# - ``DataLoader``: takes any ``Dataset`` and creates an iterator which returns batches of data.