"""
`Introduction <introyt1_tutorial.html>`_ ||
`Tensors <tensors_deeper_tutorial.html>`_ ||
`Autograd <autogradyt_tutorial.html>`_ ||
`Building Models <modelsyt_tutorial.html>`_ ||
`TensorBoard Support <tensorboardyt_tutorial.html>`_ ||
**Training Models** ||
`Model Understanding <captumyt.html>`_

Training with PyTorch
=====================

Follow along with the video below or on `youtube <https://www.youtube.com/watch?v=jF43_wj_DCQ>`__.

.. raw:: html

   <div style="margin-top:10px; margin-bottom:10px;">
     <iframe width="560" height="315" src="https://www.youtube.com/embed/jF43_wj_DCQ" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
   </div>

Introduction
------------

In past videos, we’ve discussed and demonstrated:

- Building models with the neural network layers and functions of the torch.nn module
- The mechanics of automated gradient computation, which is central to
  gradient-based model training
- Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

- We’ll get familiar with the dataset and dataloader abstractions, and how
  they ease the process of feeding data to your model during a training loop
- We’ll discuss specific loss functions and when to use them
- We’ll look at PyTorch optimizers, which implement algorithms to adjust
  model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch
training loop in action.


Dataset and DataLoader
----------------------

The ``Dataset`` and ``DataLoader`` classes encapsulate the process of
pulling your data from storage and exposing it to your training loop in
batches.

The ``Dataset`` is responsible for accessing and processing single
instances of data.

The ``DataLoader`` pulls instances of data from the ``Dataset`` (either
automatically or with a sampler that you define),
collects them in batches, and returns them for consumption by your training
loop. The ``DataLoader`` works with all kinds of datasets, regardless of the
type of data they contain.

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by
TorchVision. After converting each image to a float tensor with values in
[0, 1], we use ``torchvision.transforms.v2.Normalize()`` to zero-center and
normalize the pixel values to the range [-1, 1], and we download both the
training and validation data splits.

"""

import torch
import torchvision
from torchvision.transforms import v2

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.5,), (0.5,))
])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print(f'Training set has {len(training_set)} instances')
print(f'Validation set has {len(validation_set)} instances')


######################################################################
# As always, let’s visualize the data as a sanity check:
#

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))


#########################################################################
# The Model
# ---------
#
# The model we’ll use in this example is a variant of LeNet-5 - it should
# be familiar if you’ve watched the previous videos in this series.
#

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # 1x28x28 -> conv1 -> 6x24x24 -> pool -> 6x12x12
        x = self.pool(F.relu(self.conv1(x)))
        # 6x12x12 -> conv2 -> 16x8x8 -> pool -> 16x4x4
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten the 16 feature maps of size 4x4 for the fully connected layers
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Raw class scores (logits), one per class; no softmax here
        x = self.fc3(x)
        return x


model = GarmentClassifier()


##########################################################################
# Loss Function
# -------------
#
# For this example, we’ll be using a cross-entropy loss.
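#
# A minimal sketch of what this loss computes, using random scores rather
# than real model outputs: ``CrossEntropyLoss`` expects raw, unnormalized
# class scores (logits) and is equivalent to applying ``log_softmax`` over
# the class dimension followed by the negative log-likelihood loss. This is
# consistent with ``GarmentClassifier.forward()`` returning the output of
# ``fc3`` without a softmax. The variable names in this cell are our own,
# chosen for illustration only.
#

import torch
import torch.nn.functional as F

# A batch of 4 samples with 10 random class scores each (illustrative values)
logits = torch.rand(4, 10)
targets = torch.tensor([1, 5, 3, 7])

# Built-in criterion; by default it averages the per-sample losses
builtin_loss = torch.nn.CrossEntropyLoss()(logits, targets)

# Manual equivalent: log-softmax over the classes, take each sample's
# log-probability for its correct class, negate, and average over the batch
log_probs = F.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(len(targets)), targets].mean()

print(builtin_loss.item(), manual_loss.item())

##########################################################################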
# For demonstration purposes, we’ll create batches of dummy output and
# label values, run them through the loss function, and examine the result.
#

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
# By default, CrossEntropyLoss averages the per-sample losses over the batch
print(f'Mean loss for this batch: {loss.item()}')


#################################################################################
# Optimizer
# ---------
#
# For this example, we’ll be using simple `stochastic gradient
# descent <https://pytorch.org/docs/stable/optim.html>`__ with momentum.
#
# It can be instructive to try some variations on this optimization
# scheme:
#
# - Learning rate determines the size of the steps the optimizer
#   takes. What does a different learning rate do to your training
#   results, in terms of accuracy and convergence time?
# - Momentum accumulates a decaying sum of past gradients, nudging the
#   optimizer to keep moving in directions that have been consistent over
#   multiple steps. What does changing this value do to your results?
# - Try some different optimization algorithms, such as averaged SGD, Adagrad, or
#   Adam. How do your results differ?
#

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


#######################################################################################
# The Training Loop
# -----------------
#
# Below, we have a function that performs one training epoch.
# It enumerates data from the DataLoader, and on each pass of the loop does
# the following:
#
# - Gets a batch of training data from the DataLoader
# - Zeros the optimizer’s gradients
# - Performs a forward pass - that is, gets predictions from the model for an input batch
# - Calculates the loss for that set of predictions vs. the labels on the dataset
# - Calculates the backward gradients over the learning weights
# - Tells the optimizer to perform one learning step - that is, adjust the model’s
#   learning weights based on the observed gradients for this batch, according to the
#   optimization algorithm we chose
# - Every 1000 batches, reports the average per-batch loss to the console and to TensorBoard
# - Finally, returns the average per-batch loss for the last
#   1000 batches, for comparison with a validation run
#

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print(f' batch {i + 1} loss: {last_loss}')
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss


##################################################################################
# Per-Epoch Activity
# ~~~~~~~~~~~~~~~~~~
#
# There are a couple of things we’ll want to do once per
# epoch:
#
# - Perform validation by checking our relative loss on a set of data that was not
#   used for training, and report this
# - Save a copy of the model whenever its validation loss is the best seen so far
#
# Here, we’ll do our reporting in TensorBoard. This will require going to
# the command line to start TensorBoard (for example, with
# ``tensorboard --logdir=runs``), and opening it in another browser tab.
#

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(f'runs/fashion_trainer_{timestamp}')
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print(f'EPOCH {epoch_number + 1}:')

    # Put the model in training mode (this matters for layers such as dropout
    # and batch normalization), and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            # Accumulate a plain Python float rather than a tensor
            running_vloss += vloss.item()

    avg_vloss = running_vloss / (i + 1)
    print(f'LOSS train {avg_loss} valid {avg_vloss}')

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                       { 'Training' : avg_loss, 'Validation' : avg_vloss },
                       epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = f'model_{timestamp}_{epoch_number}'
        torch.save(model.state_dict(), model_path)

    epoch_number += 1


#########################################################################
# To load a saved version of the model:
#
# .. code:: python
#
#     saved_model = GarmentClassifier()
#     # PATH is the file saved during training, e.g. model_path above
#     saved_model.load_state_dict(torch.load(PATH))
#
# Once you’ve loaded the model, it’s ready for whatever you need it for -
# more training, inference, or analysis.
#
# Note that if your model has constructor parameters that affect model
# structure, you’ll need to provide them and configure the model
# identically to the state in which it was saved.
#
# Other Resources
# ---------------
#
# - Docs on the `data
#   utilities <https://pytorch.org/docs/stable/data.html>`__, including
#   Dataset and DataLoader, at pytorch.org
# - A `note on the use of pinned
#   memory <https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-pinning>`__
#   for GPU training
# - Documentation on the datasets available in
#   `TorchVision <https://pytorch.org/vision/stable/datasets.html>`__,
#   `TorchText <https://pytorch.org/text/stable/datasets.html>`__, and
#   `TorchAudio <https://pytorch.org/audio/stable/datasets.html>`__
# - Documentation on the `loss
#   functions <https://pytorch.org/docs/stable/nn.html#loss-functions>`__
#   available in PyTorch
# - Documentation on the `torch.optim
#   package <https://pytorch.org/docs/stable/optim.html>`__, which
#   includes optimizers and related tools, such as learning rate
#   scheduling
# - A detailed `tutorial on saving and loading
#   models <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`__
# - The `Tutorials section of
#   pytorch.org <https://pytorch.org/tutorials/>`__ contains tutorials on
#   a broad variety of training tasks, including classification in
#   different domains, generative adversarial networks, reinforcement
#   learning, and more
#
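

######################################################################
# As a minimal, self-contained sketch of the earlier note about constructor
# parameters, this cell saves a ``state_dict`` and restores it into a freshly
# constructed model. ``TinyClassifier`` and its ``hidden_size`` argument are
# hypothetical, invented for illustration (they are not part of this
# tutorial), and an in-memory ``io.BytesIO`` buffer stands in for a file
# path. Rebuilding with the same ``hidden_size`` should reproduce the
# original outputs; rebuilding with a different one makes
# ``load_state_dict()`` raise a ``RuntimeError`` because the saved parameter
# shapes no longer fit.
#

import io

import torch
import torch.nn as nn


# Hypothetical model whose layer sizes depend on a constructor argument
class TinyClassifier(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(8, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 3)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


original = TinyClassifier(hidden_size=16)

# Save only the learned parameters, here into memory rather than a file
buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)

# Rebuild with the SAME constructor arguments, then load the saved weights
buffer.seek(0)
restored = TinyClassifier(hidden_size=16)
restored.load_state_dict(torch.load(buffer, weights_only=True))

x = torch.randn(2, 8)
with torch.no_grad():
    outputs_match = torch.allclose(original(x), restored(x))
print(f'Restored model reproduces outputs: {outputs_match}')

# Rebuild with a DIFFERENT hidden_size: the saved tensors no longer fit
buffer.seek(0)
mismatched = TinyClassifier(hidden_size=32)
try:
    mismatched.load_state_dict(torch.load(buffer, weights_only=True))
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(f'Mismatched constructor raised RuntimeError: {mismatch_raised}')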