GitHub Repository: pytorch/tutorials
Path: blob/main/recipes_source/recipes/what_is_state_dict.py
"""
What is a state_dict in PyTorch
===============================
In PyTorch, the learnable parameters (i.e. weights and biases) of a
``torch.nn.Module`` model are contained in the model’s parameters
(accessed with ``model.parameters()``). A ``state_dict`` is simply a
Python dictionary object that maps each layer’s parameters (and
registered buffers, described below) to their tensors, under keys such
as ``conv1.weight``.

Introduction
------------
A ``state_dict`` is central to saving and loading models in PyTorch.
Because ``state_dict`` objects are Python dictionaries, they can be
easily saved, updated, altered, and restored, adding a great deal of
modularity to PyTorch models and optimizers.

Note that only layers with learnable parameters (convolutional layers,
linear layers, etc.) and registered buffers (such as batchnorm’s
``running_mean``) have entries in the model’s ``state_dict``. Optimizer
objects (``torch.optim``) also have a ``state_dict``, which contains
information about the optimizer’s state as well as the hyperparameters
used.

In this recipe, we will see how ``state_dict`` is used with a simple
model.

Setup
-----
Before we begin, we need to install ``torch`` if it isn’t already
available.

.. code-block:: sh

   pip install torch

"""


######################################################################
# Steps
# -----
#
# 1. Import all necessary libraries
# 2. Define and initialize the neural network
# 3. Initialize the optimizer
# 4. Access the model and optimizer ``state_dict``
#
# 1. Import necessary libraries
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For this recipe, we will use ``torch`` and its submodules ``torch.nn``,
# ``torch.nn.functional``, and ``torch.optim``.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


######################################################################
# 2. Define and initialize the neural network
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For the sake of example, we will create a neural network for image
# classification. To learn more, see the Defining a Neural Network
# recipe.
#

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Two convolutional layers with 5x5 kernels, each followed by
        # 2x2 max pooling in forward()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # For 3x32x32 inputs, the pooled conv2 output is 16 x 5 x 5,
        # which fixes fc1's input size
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # 10 output classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten everything except the batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
print(net)


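######################################################################
# The keys of a model's ``state_dict`` are built from the attribute names
# used in ``__init__``, so ``Net``'s ``conv1`` layer contributes the keys
# ``conv1.weight`` and ``conv1.bias``, while its parameter-free ``pool``
# layer contributes none. The sketch below checks the note from the
# introduction on a small, hypothetical ``nn.Sequential`` model (not part
# of this recipe): submodules are named by their index, ``nn.ReLU`` and
# ``nn.MaxPool2d`` add no keys, and ``nn.BatchNorm2d`` adds its registered
# buffers alongside its learnable parameters.
#

import torch.nn as nn

# Hypothetical toy model, used only to illustrate which layers get entries
toy = nn.Sequential(
    nn.Conv2d(3, 6, 5),  # "0": learnable weight and bias
    nn.BatchNorm2d(6),   # "1": weight and bias, plus running-statistics buffers
    nn.ReLU(),           # "2": no parameters or buffers
    nn.MaxPool2d(2, 2),  # "3": no parameters or buffers
)

state_keys = list(toy.state_dict().keys())
param_names = [name for name, _ in toy.named_parameters()]

# Entries that are registered buffers rather than learnable parameters
buffer_keys = [key for key in state_keys if key not in param_names]

print("state_dict keys:", state_keys)
print("parameter names:", param_names)
print("buffer-only keys:", buffer_keys)

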
######################################################################
# 3. Initialize the optimizer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We will use stochastic gradient descent (SGD) with momentum.
#

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)


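######################################################################
# The hyperparameters passed to the optimizer end up in the
# ``param_groups`` entry of its ``state_dict``. The sketch below uses a
# hypothetical single ``nn.Linear`` layer (a stand-in, not the recipe's
# ``Net``) with the same SGD settings. Note that ``param_groups`` refers
# to the parameters by integer index instead of storing the tensors, so
# the optimizer's ``state_dict`` does not contain the model's weights;
# those live only in the model's ``state_dict``.
#

import torch.nn as nn
import torch.optim as optim

layer = nn.Linear(4, 2)  # hypothetical stand-in model with 2 parameters
sgd = optim.SGD(layer.parameters(), lr=0.001, momentum=0.9)

group = sgd.state_dict()["param_groups"][0]
print("lr:", group["lr"])
print("momentum:", group["momentum"])
print("params:", group["params"])  # integer indices, not tensors

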
######################################################################
# 4. Access the model and optimizer ``state_dict``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Now that we have constructed our model and optimizer, we can inspect
# what is stored in their respective ``state_dict`` objects, which are
# returned by each object's ``state_dict()`` method.
#

# Print the model's state_dict: each key maps to a parameter or buffer tensor
print("Model's state_dict:")
for name, tensor in net.state_dict().items():
    print(name, "\t", tensor.size())

print()

# Print the optimizer's state_dict: its "state" and "param_groups" entries
print("Optimizer's state_dict:")
for var_name, value in optimizer.state_dict().items():
    print(var_name, "\t", value)


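######################################################################
# Because the optimizer above has not taken a step yet, the ``state``
# entry printed above is an empty dictionary: SGD with momentum creates
# its per-parameter ``momentum_buffer`` tensors the first time
# ``step()`` is called. The sketch below shows this with a hypothetical
# single ``nn.Linear`` layer, random inputs, and a toy squared-output
# loss; none of these are part of the recipe.
#

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
layer = nn.Linear(4, 2)  # hypothetical stand-in model
sgd = optim.SGD(layer.parameters(), lr=0.001, momentum=0.9)

state_before = sgd.state_dict()["state"]
print("state before step:", state_before)

# One training step on random data creates the momentum buffers
inputs = torch.randn(8, 4)
loss = layer(inputs).pow(2).mean()  # toy loss, for illustration only
sgd.zero_grad()
loss.backward()
sgd.step()

state_after = sgd.state_dict()["state"]
for index, buffers in state_after.items():
    # Keys are the same integer indices used in param_groups
    print(index, {name: tuple(t.shape) for name, t in buffers.items()})

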
######################################################################
# These dictionaries are what you save and load when you want to reuse a
# model and its optimizer later, for example to run inference or to
# resume training.
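######################################################################
# As a minimal sketch of that use, the code below saves a model's and an
# optimizer's ``state_dict`` with ``torch.save`` and restores them into
# freshly constructed objects with ``load_state_dict``. To stay
# self-contained it writes to an in-memory ``io.BytesIO`` buffer instead
# of a file, and it assumes a hypothetical ``nn.Linear`` model and a
# hypothetical checkpoint layout (a dictionary with ``"model"`` and
# ``"optimizer"`` keys); the recipes linked below cover the full workflow.
#

import io

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)  # hypothetical model
opt = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Save both state_dicts into one checkpoint dictionary (in memory here)
buffer = io.BytesIO()
torch.save({"model": model.state_dict(), "optimizer": opt.state_dict()}, buffer)
buffer.seek(0)

# Restore into new objects built with the same architecture
checkpoint = torch.load(buffer)
restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint["model"])
# Deliberately different hyperparameters; load_state_dict overwrites them
restored_opt = optim.SGD(restored.parameters(), lr=0.1)
restored_opt.load_state_dict(checkpoint["optimizer"])

same_weights = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print("weights match:", same_weights)
print("restored lr:", restored_opt.param_groups[0]["lr"])

######################################################################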
#
# Congratulations! You have successfully used ``state_dict`` in PyTorch.
#
# Learn More
# ----------
#
# Take a look at these other recipes to continue your learning:
#
# - `Saving and loading models for inference in PyTorch <https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html>`__
# - `Saving and loading a general checkpoint in PyTorch <https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html>`__