Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Path: blob/main/beginner_source/introyt/tensorboardyt_tutorial.py
Views: 713
"""1`Introduction <introyt1_tutorial.html>`_ ||2`Tensors <tensors_deeper_tutorial.html>`_ ||3`Autograd <autogradyt_tutorial.html>`_ ||4`Building Models <modelsyt_tutorial.html>`_ ||5**TensorBoard Support** ||6`Training Models <trainingyt.html>`_ ||7`Model Understanding <captumyt.html>`_89PyTorch TensorBoard Support10===========================1112Follow along with the video below or on `youtube <https://www.youtube.com/watch?v=6CEld3hZgqc>`__.1314.. raw:: html1516<div style="margin-top:10px; margin-bottom:10px;">17<iframe width="560" height="315" src="https://www.youtube.com/embed/6CEld3hZgqc" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>18</div>1920Before You Start21----------------2223To run this tutorial, you’ll need to install PyTorch, TorchVision,24Matplotlib, and TensorBoard.2526With ``conda``:2728.. code-block:: sh2930conda install pytorch torchvision -c pytorch31conda install matplotlib tensorboard3233With ``pip``:3435.. code-block:: sh3637pip install torch torchvision matplotlib tensorboard3839Once the dependencies are installed, restart this notebook in the Python40environment where you installed them.414243Introduction44------------4546In this notebook, we’ll be training a variant of LeNet-5 against the47Fashion-MNIST dataset. Fashion-MNIST is a set of image tiles depicting48various garments, with ten class labels indicating the type of garment49depicted.5051"""5253# PyTorch model and training necessities54import torch55import torch.nn as nn56import torch.nn.functional as F57import torch.optim as optim5859# Image datasets and image manipulation60import torchvision61import torchvision.transforms as transforms6263# Image display64import matplotlib.pyplot as plt65import numpy as np6667# PyTorch TensorBoard support68from torch.utils.tensorboard import SummaryWriter6970# In case you are using an environment that has TensorFlow installed,71# such as Google Colab, uncomment the following code to avoid72# a bug with saving embeddings to your TensorBoard directory7374# import tensorflow as tf75# import tensorboard as tb76# tf.io.gfile = tb.compat.tensorflow_stub.io.gfile7778######################################################################79# Showing Images in TensorBoard80# -----------------------------81#82# Let’s start by adding sample images from our dataset to TensorBoard:83#8485# Gather datasets and prepare them for consumption86transform = transforms.Compose(87[transforms.ToTensor(),88transforms.Normalize((0.5,), (0.5,))])8990# Store separate training and validations splits in ./data91training_set = torchvision.datasets.FashionMNIST('./data',92download=True,93train=True,94transform=transform)95validation_set = torchvision.datasets.FashionMNIST('./data',96download=True,97train=False,98transform=transform)99100training_loader = torch.utils.data.DataLoader(training_set,101batch_size=4,102shuffle=True,103num_workers=2)104105106validation_loader = torch.utils.data.DataLoader(validation_set,107batch_size=4,108shuffle=False,109num_workers=2)110111# Class labels112classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',113'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')114115# Helper function for inline image display116def matplotlib_imshow(img, one_channel=False):117if one_channel:118img = img.mean(dim=0)119img = img / 2 + 0.5 # unnormalize120npimg = img.numpy()121if one_channel:122plt.imshow(npimg, cmap="Greys")123else:124plt.imshow(np.transpose(npimg, (1, 2, 0)))125126# Extract a batch of 4 images127dataiter = iter(training_loader)128images, labels = next(dataiter)129130# Create a grid from the images and show them131img_grid = torchvision.utils.make_grid(images)132matplotlib_imshow(img_grid, one_channel=True)133134135########################################################################136# Above, we used TorchVision and Matplotlib to create a visual grid of a137# minibatch of our input data. Below, we use the ``add_image()`` call on138# ``SummaryWriter`` to log the image for consumption by TensorBoard, and139# we also call ``flush()`` to make sure it’s written to disk right away.140#141142# Default log_dir argument is "runs" - but it's good to be specific143# torch.utils.tensorboard.SummaryWriter is imported above144writer = SummaryWriter('runs/fashion_mnist_experiment_1')145146# Write image data to TensorBoard log dir147writer.add_image('Four Fashion-MNIST Images', img_grid)148writer.flush()149150# To view, start TensorBoard on the command line with:151# tensorboard --logdir=runs152# ...and open a browser tab to http://localhost:6006/153154155##########################################################################156# If you start TensorBoard at the command line and open it in a new157# browser tab (usually at `localhost:6006 <localhost:6006>`__), you should158# see the image grid under the IMAGES tab.159#160# Graphing Scalars to Visualize Training161# --------------------------------------162#163# TensorBoard is useful for tracking the progress and efficacy of your164# training. Below, we’ll run a training loop, track some metrics, and save165# the data for TensorBoard’s consumption.166#167# Let’s define a model to categorize our image tiles, and an optimizer and168# loss function for training:169#170171class Net(nn.Module):172def __init__(self):173super(Net, self).__init__()174self.conv1 = nn.Conv2d(1, 6, 5)175self.pool = nn.MaxPool2d(2, 2)176self.conv2 = nn.Conv2d(6, 16, 5)177self.fc1 = nn.Linear(16 * 4 * 4, 120)178self.fc2 = nn.Linear(120, 84)179self.fc3 = nn.Linear(84, 10)180181def forward(self, x):182x = self.pool(F.relu(self.conv1(x)))183x = self.pool(F.relu(self.conv2(x)))184x = x.view(-1, 16 * 4 * 4)185x = F.relu(self.fc1(x))186x = F.relu(self.fc2(x))187x = self.fc3(x)188return x189190191net = Net()192criterion = nn.CrossEntropyLoss()193optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)194195196##########################################################################197# Now let’s train a single epoch, and evaluate the training vs. validation198# set losses every 1000 batches:199#200201print(len(validation_loader))202for epoch in range(1): # loop over the dataset multiple times203running_loss = 0.0204205for i, data in enumerate(training_loader, 0):206# basic training loop207inputs, labels = data208optimizer.zero_grad()209outputs = net(inputs)210loss = criterion(outputs, labels)211loss.backward()212optimizer.step()213214running_loss += loss.item()215if i % 1000 == 999: # Every 1000 mini-batches...216print('Batch {}'.format(i + 1))217# Check against the validation set218running_vloss = 0.0219220# In evaluation mode some model specific operations can be omitted eg. dropout layer221net.train(False) # Switching to evaluation mode, eg. turning off regularisation222for j, vdata in enumerate(validation_loader, 0):223vinputs, vlabels = vdata224voutputs = net(vinputs)225vloss = criterion(voutputs, vlabels)226running_vloss += vloss.item()227net.train(True) # Switching back to training mode, eg. turning on regularisation228229avg_loss = running_loss / 1000230avg_vloss = running_vloss / len(validation_loader)231232# Log the running loss averaged per batch233writer.add_scalars('Training vs. Validation Loss',234{ 'Training' : avg_loss, 'Validation' : avg_vloss },235epoch * len(training_loader) + i)236237running_loss = 0.0238print('Finished Training')239240writer.flush()241242243#########################################################################244# Switch to your open TensorBoard and have a look at the SCALARS tab.245#246# Visualizing Your Model247# ----------------------248#249# TensorBoard can also be used to examine the data flow within your model.250# To do this, call the ``add_graph()`` method with a model and sample251# input:252#253254# Again, grab a single mini-batch of images255dataiter = iter(training_loader)256images, labels = next(dataiter)257258# add_graph() will trace the sample input through your model,259# and render it as a graph.260writer.add_graph(net, images)261writer.flush()262263264#########################################################################265# When you switch over to TensorBoard, you should see a GRAPHS tab.266# Double-click the “NET” node to see the layers and data flow within your267# model.268#269# Visualizing Your Dataset with Embeddings270# ----------------------------------------271#272# The 28-by-28 image tiles we’re using can be modeled as 784-dimensional273# vectors (28 \* 28 = 784). It can be instructive to project this to a274# lower-dimensional representation. The ``add_embedding()`` method will275# project a set of data onto the three dimensions with highest variance,276# and display them as an interactive 3D chart. The ``add_embedding()``277# method does this automatically by projecting to the three dimensions278# with highest variance.279#280# Below, we’ll take a sample of our data, and generate such an embedding:281#282283# Select a random subset of data and corresponding labels284def select_n_random(data, labels, n=100):285assert len(data) == len(labels)286287perm = torch.randperm(len(data))288return data[perm][:n], labels[perm][:n]289290# Extract a random subset of data291images, labels = select_n_random(training_set.data, training_set.targets)292293# get the class labels for each image294class_labels = [classes[label] for label in labels]295296# log embeddings297features = images.view(-1, 28 * 28)298writer.add_embedding(features,299metadata=class_labels,300label_img=images.unsqueeze(1))301writer.flush()302writer.close()303304305#######################################################################306# Now if you switch to TensorBoard and select the PROJECTOR tab, you307# should see a 3D representation of the projection. You can rotate and308# zoom the model. Examine it at large and small scales, and see whether309# you can spot patterns in the projected data and the clustering of310# labels.311#312# For better visibility, it’s recommended to:313#314# - Select “label” from the “Color by” drop-down on the left.315# - Toggle the Night Mode icon along the top to place the316# light-colored images on a dark background.317#318# Other Resources319# ---------------320#321# For more information, have a look at:322#323# - PyTorch documentation on `torch.utils.tensorboard.SummaryWriter <https://pytorch.org/docs/stable/tensorboard.html?highlight=summarywriter>`__324# - Tensorboard tutorial content in the `PyTorch.org Tutorials <https://pytorch.org/tutorials/>`__325# - For more information about TensorBoard, see the `TensorBoard326# documentation <https://www.tensorflow.org/tensorboard>`__327328329