"""1`Introduction <introyt1_tutorial.html>`_ ||2`Tensors <tensors_deeper_tutorial.html>`_ ||3`Autograd <autogradyt_tutorial.html>`_ ||4`Building Models <modelsyt_tutorial.html>`_ ||5`TensorBoard Support <tensorboardyt_tutorial.html>`_ ||6**Training Models** ||7`Model Understanding <captumyt.html>`_89Training with PyTorch10=====================1112Follow along with the video below or on `youtube <https://www.youtube.com/watch?v=jF43_wj_DCQ>`__.1314.. raw:: html1516<div style="margin-top:10px; margin-bottom:10px;">17<iframe width="560" height="315" src="https://www.youtube.com/embed/jF43_wj_DCQ" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>18</div>1920Introduction21------------2223In past videos, we’ve discussed and demonstrated:2425- Building models with the neural network layers and functions of the torch.nn module26- The mechanics of automated gradient computation, which is central to27gradient-based model training28- Using TensorBoard to visualize training progress and other activities2930In this video, we’ll be adding some new tools to your inventory:3132- We’ll get familiar with the dataset and dataloader abstractions, and how33they ease the process of feeding data to your model during a training loop34- We’ll discuss specific loss functions and when to use them35- We’ll look at PyTorch optimizers, which implement algorithms to adjust36model weights based on the outcome of a loss function3738Finally, we’ll pull all of these together and see a full PyTorch39training loop in action.404142Dataset and DataLoader43----------------------4445The ``Dataset`` and ``DataLoader`` classes encapsulate the process of46pulling your data from storage and exposing it to your training loop in47batches.4849The ``Dataset`` is responsible for accessing and processing single50instances of data.5152The ``DataLoader`` pulls instances of data from the ``Dataset`` (either53automatically or with a sampler that you define), collects them in54batches, and returns them for consumption by your training loop. The55``DataLoader`` works with all kinds of datasets, regardless of the type56of data they contain.5758For this tutorial, we’ll be using the Fashion-MNIST dataset provided by59TorchVision. 
For this tutorial, we’ll be using the Fashion-MNIST dataset provided by
TorchVision. We use ``torchvision.transforms.Normalize()`` to
zero-center and normalize the distribution of the image tile content,
and download both training and validation data splits.

"""

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))


######################################################################
# As always, let’s visualize the data as a sanity check:
#

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))


#########################################################################
# The Model
# ---------
#
# The model we’ll use in this example is a variant of LeNet-5 - it should
# be familiar if you’ve watched the previous videos in this series.
#

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
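#########################################################################
# Before training, a quick shape check confirms the network is wired the
# way we expect - a minimal sanity check of our own (the ``dummy_input``
# name is just for illustration): a batch of one 1-channel 28x28 image
# in, one confidence score per class out.
#

dummy_input = torch.rand(1, 1, 28, 28)   # one single-channel 28x28 image
print(model(dummy_input).shape)          # expect: torch.Size([1, 10])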
##########################################################################
# Loss Function
# -------------
#
# For this example, we’ll be using a cross-entropy loss. For demonstration
# purposes, we’ll create batches of dummy output and label values, run
# them through the loss function, and examine the result.
#

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
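######################################################################
# As a minimal sketch of what ``CrossEntropyLoss`` computes (the
# ``log_probs`` and ``manual_loss`` names here are ours, for illustration
# only): it applies log-softmax to the raw scores, then averages the
# negative log-probabilities assigned to the correct classes.
#

log_probs = F.log_softmax(dummy_outputs, dim=1)
manual_loss = -log_probs[torch.arange(4), dummy_labels].mean()
print(manual_loss.item())   # should match loss.item() above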
#################################################################################
# Optimizer
# ---------
#
# For this example, we’ll be using simple `stochastic gradient
# descent <https://pytorch.org/docs/stable/optim.html>`__ with momentum.
#
# It can be instructive to try some variations on this optimization
# scheme:
#
# - Learning rate determines the size of the steps the optimizer
#   takes. What does a different learning rate do to your training
#   results, in terms of accuracy and convergence time?
# - Momentum nudges the optimizer in the direction of strongest gradient over
#   multiple steps. What does changing this value do to your results?
# - Try some different optimization algorithms, such as averaged SGD, Adagrad, or
#   Adam. How do your results differ? (A sketch of one such swap follows the
#   optimizer definition below.)
#

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
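######################################################################
# As one example of such a variation - a sketch for illustration only,
# with the ``adam_optimizer`` name being our own - swapping in a
# different algorithm is a one-line change:
#
# .. code:: python
#
#     # Adam adapts per-parameter step sizes from running gradient
#     # statistics, so unlike SGD it has no momentum argument
#     adam_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#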
#######################################################################################
# The Training Loop
# -----------------
#
# Below, we have a function that performs one training epoch. It
# enumerates data from the DataLoader, and on each pass of the loop does
# the following:
#
# - Gets a batch of training data from the DataLoader
# - Zeros the optimizer’s gradients
# - Performs an inference - that is, gets predictions from the model for an input batch
# - Calculates the loss for that set of predictions vs. the labels on the dataset
# - Calculates the backward gradients over the learning weights
# - Tells the optimizer to perform one learning step - that is, adjust the model’s
#   learning weights based on the observed gradients for this batch, according to the
#   optimization algorithm we chose
# - Reports the loss for every 1000 batches
# - Finally, reports the average per-batch loss over the last
#   1000 batches, for comparison with a validation run
#

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss


##################################################################################
# Per-Epoch Activity
# ~~~~~~~~~~~~~~~~~~
#
# There are a couple of things we’ll want to do once per epoch:
#
# - Perform validation by checking our relative loss on a set of data that was not
#   used for training, and report this
# - Save a copy of the model
#
# Here, we’ll do our reporting in TensorBoard. This will require going to
# the command line to start TensorBoard (typically ``tensorboard
# --logdir=runs``, pointing at the log directory we write to below), and
# opening it in another browser tab.
#

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1


#########################################################################
# To load a saved version of the model:
#
# .. code:: python
#
#     saved_model = GarmentClassifier()
#     saved_model.load_state_dict(torch.load(PATH))
#
# Once you’ve loaded the model, it’s ready for whatever you need it for -
# more training, inference, or analysis.
#
# Note that if your model has constructor parameters that affect model
# structure, you’ll need to provide them and configure the model
# identically to the state in which it was saved.
#
# Other Resources
# ---------------
#
# - Docs on the `data
#   utilities <https://pytorch.org/docs/stable/data.html>`__, including
#   Dataset and DataLoader, at pytorch.org
# - A `note on the use of pinned
#   memory <https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-pinning>`__
#   for GPU training
# - Documentation on the datasets available in
#   `TorchVision <https://pytorch.org/vision/stable/datasets.html>`__,
#   `TorchText <https://pytorch.org/text/stable/datasets.html>`__, and
#   `TorchAudio <https://pytorch.org/audio/stable/datasets.html>`__
# - Documentation on the `loss
#   functions <https://pytorch.org/docs/stable/nn.html#loss-functions>`__
#   available in PyTorch
# - Documentation on the `torch.optim
#   package <https://pytorch.org/docs/stable/optim.html>`__, which
#   includes optimizers and related tools, such as learning rate
#   scheduling
# - A detailed `tutorial on saving and loading
#   models <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`__
# - The `Tutorials section of
#   pytorch.org <https://pytorch.org/tutorials/>`__ contains tutorials on
#   a broad variety of training tasks, including classification in
#   different domains, generative adversarial networks, reinforcement
#   learning, and more
#