"""
What is a state_dict in PyTorch
===============================
In PyTorch, the learnable parameters (i.e. weights and biases) of a
``torch.nn.Module`` model are contained in the model’s parameters
(accessed with ``model.parameters()``). A ``state_dict`` is simply a
Python dictionary object that maps each parameter and buffer name (for
example, ``conv1.weight``) to the corresponding tensor.

Introduction
------------
A ``state_dict`` is central to saving and loading PyTorch models.
Because ``state_dict`` objects are Python dictionaries, they can be
easily saved, updated, altered, and restored, adding a great deal of
modularity to PyTorch models and optimizers.
Note that only layers with learnable parameters (convolutional layers,
linear layers, etc.) and registered buffers (for example, batchnorm’s
``running_mean``) have entries in the model’s ``state_dict``; layers
with neither, such as max pooling, do not appear. Optimizer objects
(``torch.optim``) also have a ``state_dict``, which contains the
optimizer’s state (under the ``state`` key) and the hyperparameters used
(under the ``param_groups`` key).
In this recipe, we will see how ``state_dict`` is used with a simple
model.

Setup
-----
Before we begin, we need to install ``torch`` if it isn’t already
available.

.. code-block:: sh

   pip install torch

"""



######################################################################
# Steps
# -----
#
# 1. Import the necessary libraries
# 2. Define and initialize the neural network
# 3. Initialize the optimizer
# 4. Access the model and optimizer ``state_dict``
#
# 1. Import the necessary libraries
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For this recipe, we will use ``torch`` and its submodules ``torch.nn``,
# ``torch.nn.functional``, and ``torch.optim``.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


######################################################################
# 2. Define and initialize the neural network
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For the sake of example, we will create a small convolutional network
# that classifies 3-channel, 32x32 images into 10 classes. To learn more,
# see the Defining a Neural Network recipe.
#

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Convolutional feature extractor with a shared max-pooling layer
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Fully connected classifier; 16 * 5 * 5 follows from 32x32 inputs
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)


######################################################################
# 3. Initialize the optimizer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We will use stochastic gradient descent (SGD) with momentum.
#

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)


######################################################################
# 4. Access the model and optimizer ``state_dict``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Now that we have constructed our model and optimizer, we can inspect
# what is stored in their respective ``state_dict`` objects.
#

# Print model's state_dict: parameter/buffer name -> tensor shape
print("Model's state_dict:")
for param_tensor, tensor in net.state_dict().items():
    print(param_tensor, "\t", tensor.size())

print()

# Print optimizer's state_dict: the ``state`` and ``param_groups`` entries
print("Optimizer's state_dict:")
for var_name, value in optimizer.state_dict().items():
    print(var_name, "\t", value)


######################################################################
# Notice that ``pool`` has no entries in the model's ``state_dict``,
# because max pooling has no learnable parameters. The optimizer's
# ``state`` entry is also empty, because no optimization step has been
# taken yet. These dictionaries are what you save and load to preserve a
# model and its optimizer for future use.
#
# Congratulations! You have successfully used ``state_dict`` in PyTorch.
#
# Learn More
# ----------
#
# Take a look at these other recipes to continue your learning:
#
# - `Saving and loading models for inference in PyTorch <https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html>`__
# - `Saving and loading a general checkpoint in PyTorch <https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html>`__
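
######################################################################
# The introduction's rule about which layers appear in a ``state_dict``
# can be checked directly. The following minimal sketch inspects a small
# hypothetical model, ``TinyNet`` (not part of this recipe), that
# combines a convolution, a batch normalization layer, and a max-pooling
# layer. The convolution should contribute its weight and bias. The batch
# normalization layer should contribute its learnable affine parameters
# plus its running-statistics buffers. The pooling layer should
# contribute nothing.
#

import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model used only to illustrate ``state_dict`` contents."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, 3)  # learnable weight and bias
        # learnable weight/bias plus running_mean, running_var and
        # num_batches_tracked buffers
        self.bn = nn.BatchNorm2d(4)
        self.pool = nn.MaxPool2d(2)  # no parameters or buffers

    def forward(self, x):
        return self.pool(self.bn(self.conv(x)))


# Keys are dotted names built from the attribute names above.
tiny_state = TinyNet().state_dict()
for name, tensor in tiny_state.items():
    print(name, "\t", tuple(tensor.shape))
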
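
######################################################################
# An optimizer's ``state_dict`` has two top-level keys, ``state`` and
# ``param_groups``. In the packed dictionary, parameters are referred to
# by integer index rather than by tensor. The sketch below uses a
# hypothetical single ``nn.Linear`` stand-in model and made-up random
# data. It shows that SGD's per-parameter ``state`` stays empty until the
# first ``step()``, after which it holds a ``momentum_buffer`` for each
# parameter.
#

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

model = nn.Linear(4, 2)  # hypothetical stand-in model (weight and bias)
opt = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Before any step, no per-parameter state has been created yet.
before = opt.state_dict()
print(sorted(before.keys()), before["state"])

# One optimization step on random data creates the momentum buffers.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()

after = opt.state_dict()
for index, param_state in after["state"].items():
    print(index, list(param_state.keys()))
# Hyperparameters live in param_groups; "params" holds the integer indices.
print(after["param_groups"][0]["lr"], after["param_groups"][0]["params"])
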
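
######################################################################
# Because both ``state_dict`` objects are plain dictionaries of tensors
# and Python values, one way to save and restore them is to serialize
# them with ``torch.save`` and load them back with ``load_state_dict``.
# This minimal sketch writes to an in-memory ``io.BytesIO`` buffer in
# place of a checkpoint file. The checkpoint keys ``"model"`` and
# ``"optimizer"`` are arbitrary, hypothetical names chosen for
# illustration.
#

import io

import torch
import torch.nn as nn
import torch.optim as optim

source = nn.Linear(4, 2)
source_opt = optim.SGD(source.parameters(), lr=0.001, momentum=0.9)

# Bundle both state_dicts into one checkpoint (a buffer instead of a path).
buffer = io.BytesIO()
torch.save({"model": source.state_dict(), "optimizer": source_opt.state_dict()}, buffer)
buffer.seek(0)

# Restore into freshly constructed objects with the same architecture.
checkpoint = torch.load(buffer)
target = nn.Linear(4, 2)
target_opt = optim.SGD(target.parameters(), lr=0.1, momentum=0.0)
target.load_state_dict(checkpoint["model"])
target_opt.load_state_dict(checkpoint["optimizer"])  # also restores lr/momentum

print(torch.equal(source.weight, target.weight), target_opt.param_groups[0]["lr"])

######################################################################
# Note that ``load_state_dict`` only fills in values; the target model and
# optimizer must be constructed first, with matching parameter names and
# shapes.
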