GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/saving_loading_models.py
# -*- coding: utf-8 -*-
"""
Saving and Loading Models
=========================
**Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_

This document provides solutions to a variety of use cases regarding the
saving and loading of PyTorch models. Feel free to read the whole
document, or just skip to the code you need for a desired use case.

When it comes to saving and loading models, there are three core
functions to be familiar with:

1) `torch.save <https://pytorch.org/docs/stable/torch.html?highlight=save#torch.save>`__:
   Saves a serialized object to disk. This function uses Python’s
   `pickle <https://docs.python.org/3/library/pickle.html>`__ utility
   for serialization. Models, tensors, and dictionaries of all kinds of
   objects can be saved using this function.

2) `torch.load <https://pytorch.org/docs/stable/torch.html?highlight=torch%20load#torch.load>`__:
   Uses `pickle <https://docs.python.org/3/library/pickle.html>`__\ ’s
   unpickling facilities to deserialize pickled object files to memory.
   This function also lets you choose the device to load the data onto
   (see `Saving & Loading Model Across
   Devices <#saving-loading-model-across-devices>`__).

3) `torch.nn.Module.load_state_dict <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict>`__:
   Loads a model’s parameter dictionary using a deserialized
   *state_dict*. For more information on *state_dict*, see `What is a
   state_dict? <#what-is-a-state-dict>`__.

**Contents:**

- `What is a state_dict? <#what-is-a-state-dict>`__
- `Saving & Loading Model for
  Inference <#saving-loading-model-for-inference>`__
- `Saving & Loading a General
  Checkpoint <#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training>`__
- `Saving Multiple Models in One
  File <#saving-multiple-models-in-one-file>`__
- `Warmstarting Model Using Parameters from a Different
  Model <#warmstarting-model-using-parameters-from-a-different-model>`__
- `Saving & Loading Model Across
  Devices <#saving-loading-model-across-devices>`__

"""

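######################################################################
# As a quick, self-contained illustration of the first two functions, the
# minimal sketch below round-trips a dictionary of tensors through an
# in-memory buffer. It assumes ``io.BytesIO`` as a stand-in for a file
# path (both ``torch.save`` and ``torch.load`` accept file-like objects).
# The dictionary contents are arbitrary, hypothetical examples.
#

import io

import torch

# Hypothetical payload: any mix of tensors and plain Python values.
payload = {"weights": torch.arange(6.0).reshape(2, 3), "step": 10}

buffer = io.BytesIO()
torch.save(payload, buffer)  # pickle-based serialization into the buffer
buffer.seek(0)               # rewind before reading it back
restored = torch.load(buffer, weights_only=True)

print(restored["step"])                                      # 10
print(torch.equal(restored["weights"], payload["weights"]))  # True
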
######################################################################
# What is a ``state_dict``?
# -------------------------
#
# In PyTorch, the learnable parameters (i.e. weights and biases) of a
# ``torch.nn.Module`` model are contained in the model’s *parameters*
# (accessed with ``model.parameters()``). A *state_dict* is simply a
# Python dictionary object that maps each parameter and buffer name
# (such as ``conv1.weight``) to the corresponding tensor.
# Note that only layers with learnable parameters (convolutional layers,
# linear layers, etc.) and registered buffers (e.g. batchnorm's
# ``running_mean``) have entries in the model’s *state_dict*. Optimizer
# objects (``torch.optim``) also have a *state_dict*, which contains
# information about the optimizer's state, as well as the hyperparameters
# used.
#
# Because *state_dict* objects are Python dictionaries, they can be easily
# saved, updated, altered, and restored, adding a great deal of modularity
# to PyTorch models and optimizers.
#
# Example:
# ^^^^^^^^
#
# Let’s take a look at the *state_dict* from the simple model used in the
# `Training a
# classifier <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py>`__
# tutorial.
#
# .. code:: python
#
#    import torch
#    import torch.nn as nn
#    import torch.nn.functional as F
#    import torch.optim as optim
#
#    # Define model
#    class TheModelClass(nn.Module):
#        def __init__(self):
#            super(TheModelClass, self).__init__()
#            self.conv1 = nn.Conv2d(3, 6, 5)
#            self.pool = nn.MaxPool2d(2, 2)
#            self.conv2 = nn.Conv2d(6, 16, 5)
#            self.fc1 = nn.Linear(16 * 5 * 5, 120)
#            self.fc2 = nn.Linear(120, 84)
#            self.fc3 = nn.Linear(84, 10)
#
#        def forward(self, x):
#            x = self.pool(F.relu(self.conv1(x)))
#            x = self.pool(F.relu(self.conv2(x)))
#            x = x.view(-1, 16 * 5 * 5)
#            x = F.relu(self.fc1(x))
#            x = F.relu(self.fc2(x))
#            x = self.fc3(x)
#            return x
#
#    # Initialize model
#    model = TheModelClass()
#
#    # Initialize optimizer
#    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
#
#    # Print model's state_dict
#    print("Model's state_dict:")
#    for param_tensor in model.state_dict():
#        print(param_tensor, "\t", model.state_dict()[param_tensor].size())
#
#    # Print optimizer's state_dict
#    print("Optimizer's state_dict:")
#    for var_name in optimizer.state_dict():
#        print(var_name, "\t", optimizer.state_dict()[var_name])
#
# **Output** (the exact optimizer entries vary across PyTorch versions):
#
# .. code-block:: sh
#
#    Model's state_dict:
#    conv1.weight torch.Size([6, 3, 5, 5])
#    conv1.bias torch.Size([6])
#    conv2.weight torch.Size([16, 6, 5, 5])
#    conv2.bias torch.Size([16])
#    fc1.weight torch.Size([120, 400])
#    fc1.bias torch.Size([120])
#    fc2.weight torch.Size([84, 120])
#    fc2.bias torch.Size([84])
#    fc3.weight torch.Size([10, 84])
#    fc3.bias torch.Size([10])
#
#    Optimizer's state_dict:
#    state {}
#    param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
#

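######################################################################
# The listing above contains no buffers, because that model has no batch
# normalization layers. The minimal sketch below uses a hypothetical toy
# model (not part of the original tutorial) to show that a ``BatchNorm1d``
# layer adds its registered buffers to the *state_dict*, while a
# parameter-free layer such as ``ReLU`` adds nothing.
#

import torch
import torch.nn as nn

# Hypothetical toy model: Linear (parameters), BatchNorm1d (parameters and
# buffers), ReLU (no state at all).
tiny = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3), nn.ReLU())

state = tiny.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

# Keys look like "<submodule index>.<attribute name>"; index 2 (ReLU) is absent.
keys = list(state.keys())
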
######################################################################
# Saving & Loading Model for Inference
# ------------------------------------
#
# Save/Load ``state_dict`` (Recommended)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#    model = TheModelClass(*args, **kwargs)
#    model.load_state_dict(torch.load(PATH, weights_only=True))
#    model.eval()
#
# .. note::
#
#    The 1.6 release of PyTorch switched ``torch.save`` to use a new
#    zip file-based format. ``torch.load`` still retains the ability to
#    load files in the old format. If for any reason you want ``torch.save``
#    to use the old format, pass the keyword argument
#    ``_use_new_zipfile_serialization=False``.
#
# When saving a model for inference, it is only necessary to save the
# trained model’s learned parameters. Saving the model’s *state_dict* with
# the ``torch.save()`` function gives you the most flexibility for
# restoring the model later, which is why it is the recommended method for
# saving models.
#
# A common PyTorch convention is to save models using either a ``.pt`` or
# ``.pth`` file extension.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results.
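#
# The recommended workflow can be exercised end to end with the minimal
# sketch below. It assumes a small hypothetical ``TinyNet`` in place of
# ``TheModelClass`` and an in-memory ``io.BytesIO`` buffer in place of
# ``PATH``. After loading and switching both models to evaluation mode,
# the restored model produces the same outputs as the original.
#

import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for ``TheModelClass``."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.fc(self.drop(x))


torch.manual_seed(0)
model = TinyNet()

# Save only the learned parameters (the state_dict), not the whole object.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Re-create the architecture in code, then load the deserialized dictionary.
restored = TinyNet()
restored.load_state_dict(torch.load(buffer, weights_only=True))
restored.eval()  # disable dropout so inference is deterministic
model.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(same)  # True

######################################################################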
#
# .. note::
#
#    Notice that the ``load_state_dict()`` function takes a dictionary
#    object, NOT a path to a saved object. This means that you must
#    deserialize the saved *state_dict* before you pass it to the
#    ``load_state_dict()`` function. For example, you CANNOT load using
#    ``model.load_state_dict(PATH)``.
#
# .. note::
#
#    If you only plan to keep the best performing model (according to the
#    acquired validation loss), don't forget that
#    ``best_model_state = model.state_dict()`` returns a dictionary whose
#    tensors share storage with the model's live parameters, not a copy of
#    them! You must serialize ``best_model_state`` or use
#    ``best_model_state = deepcopy(model.state_dict())`` (with ``deepcopy``
#    imported from Python's ``copy`` module); otherwise, ``best_model_state``
#    will keep getting updated by subsequent training iterations. As a
#    result, the final model state will be the state of the overfitted model.
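#
# The difference can be checked directly. In the minimal sketch below, a
# hypothetical in-place update stands in for an optimizer step. The plain
# ``state_dict()`` snapshot changes along with the model, while the
# ``deepcopy`` snapshot keeps the earlier values.
#

from copy import deepcopy

import torch
import torch.nn as nn

model = nn.Linear(2, 1)

aliased_state = model.state_dict()           # tensors share storage with the model
copied_state = deepcopy(model.state_dict())  # independent copy of every tensor

# Stand-in for a training step: modify the parameters in place.
with torch.no_grad():
    model.weight.add_(1.0)

print(torch.equal(aliased_state["weight"], model.weight))  # True (followed the update)
print(torch.equal(copied_state["weight"], model.weight))   # False (kept the old values)

######################################################################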
#
# Save/Load Entire Model
# ^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model, PATH)
#
# **Load:**
#
# .. code:: python
#
#    # Model class must be defined somewhere
#    model = torch.load(PATH, weights_only=False)
#    model.eval()
#
# This save/load process uses the most intuitive syntax and involves the
# least amount of code. Saving a model in this way will save the entire
# module using Python’s
# `pickle <https://docs.python.org/3/library/pickle.html>`__ module. The
# disadvantage of this approach is that the serialized data is bound to
# the specific classes and the exact directory structure used when the
# model is saved. This is because pickle does not save the model class
# itself. Rather, it saves a reference to the class (its module path and
# name), which is used to locate the class at load time. Because of this,
# your code can break in various ways when used in other projects or after
# refactors. Also, because ``weights_only=False`` lets ``torch.load``
# unpickle arbitrary objects, which can execute arbitrary code, only load
# such files from sources you trust.
#
# A common PyTorch convention is to save models using either a ``.pt`` or
# ``.pth`` file extension.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results.
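#
# The minimal sketch below pickles a whole instance of a hypothetical
# ``TinyNet`` to an in-memory buffer (standing in for ``PATH``) and loads
# it back. Loading works here only because ``TinyNet`` can still be found
# under the same module and name in the loading process. This is exactly
# the coupling described above.
#

import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model class; it must be importable when loading."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


buffer = io.BytesIO()
torch.save(TinyNet(), buffer)  # pickles the module object itself
buffer.seek(0)

# weights_only=False is required to unpickle an arbitrary nn.Module;
# only do this with files you trust.
loaded = torch.load(buffer, weights_only=False)
loaded.eval()
print(type(loaded).__name__)  # TinyNet

######################################################################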
#
# Export/Load Model in TorchScript Format
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# One common way to do inference with a trained model is to use
# `TorchScript <https://pytorch.org/docs/stable/jit.html>`__, an intermediate
# representation of a PyTorch model that can be run in Python as well as in a
# high-performance environment such as C++. TorchScript is the recommended
# model format for scaled inference and deployment.
#
# .. note::
#
#    Using the TorchScript format, you will be able to load the exported
#    model and run inference without defining the model class.
#
# **Export:**
#
# .. code:: python
#
#    model_scripted = torch.jit.script(model) # Export to TorchScript
#    model_scripted.save('model_scripted.pt') # Save
#
# **Load:**
#
# .. code:: python
#
#    model = torch.jit.load('model_scripted.pt')
#    model.eval()
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results.
#
# For more information on TorchScript, feel free to visit the dedicated
# `tutorials <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`__.
# There you will get familiar with the tracing conversion and learn how to
# run a TorchScript module in a `C++ environment <https://pytorch.org/tutorials/advanced/cpp_export.html>`__.


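######################################################################
# The minimal sketch below shows the full round trip. It assumes a
# hypothetical ``TinyNet`` and an in-memory buffer instead of
# ``'model_scripted.pt'``. The scripted module is saved with
# ``torch.jit.save`` and reloaded with ``torch.jit.load``. The reloaded
# module does not depend on the Python class definition, and it matches
# the original module's outputs.
#

import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model used only for this sketch."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))


eager = TinyNet().eval()
scripted = torch.jit.script(eager)  # compile to TorchScript

buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)
reloaded.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    print(torch.allclose(eager(x), reloaded(x)))  # True
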
######################################################################
# Saving & Loading a General Checkpoint for Inference and/or Resuming Training
# ----------------------------------------------------------------------------
#
# Save:
# ^^^^^
#
# .. code:: python
#
#    torch.save({
#        'epoch': epoch,
#        'model_state_dict': model.state_dict(),
#        'optimizer_state_dict': optimizer.state_dict(),
#        'loss': loss,
#        ...
#    }, PATH)
#
# Load:
# ^^^^^
#
# .. code:: python
#
#    model = TheModelClass(*args, **kwargs)
#    optimizer = TheOptimizerClass(*args, **kwargs)
#
#    checkpoint = torch.load(PATH, weights_only=True)
#    model.load_state_dict(checkpoint['model_state_dict'])
#    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#    epoch = checkpoint['epoch']
#    loss = checkpoint['loss']
#
#    model.eval()
#    # - or -
#    model.train()
#
# When saving a general checkpoint, to be used for either inference or
# resuming training, you must save more than just the model’s
# *state_dict*. It is important to also save the optimizer's *state_dict*,
# as this contains buffers and parameters (such as momentum buffers) that
# are updated as the model trains. Other items that you may want to save
# are the epoch you left off on, the latest recorded training loss,
# external ``torch.nn.Embedding`` layers, etc. As a result, such a
# checkpoint is often two to three times larger than the model alone.
#
# To save multiple components, organize them in a dictionary and use
# ``torch.save()`` to serialize the dictionary. A common PyTorch
# convention is to save these checkpoints using the ``.tar`` file
# extension.
#
# To load the items, first initialize the model and optimizer, then load
# the dictionary locally using ``torch.load()``. From here, you can easily
# access the saved items by simply querying the dictionary as you would
# expect.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results. If you
# wish to resume training, call ``model.train()`` to ensure these layers
# are in training mode.
#


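######################################################################
# The minimal sketch below puts the save and load steps together. It
# assumes a hypothetical ``nn.Linear`` model, an SGD optimizer with
# momentum, a made-up ``epoch`` value, and an in-memory buffer in place of
# ``PATH``. SGD with momentum keeps a per-parameter ``momentum_buffer``.
# Restoring the optimizer's *state_dict* is what lets training resume with
# that momentum intact.
#

import io

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One illustrative training step so the optimizer has state worth saving.
loss = model(torch.randn(8, 3)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

buffer = io.BytesIO()
torch.save({
    'epoch': 5,  # hypothetical epoch counter
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
}, buffer)
buffer.seek(0)

# Resume: re-create the objects first, then restore their state.
new_model = nn.Linear(3, 1)
new_optimizer = optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
checkpoint = torch.load(buffer, weights_only=True)
new_model.load_state_dict(checkpoint['model_state_dict'])
new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
new_model.train()  # resuming training

momentum = new_optimizer.state[new_model.weight]['momentum_buffer']
print(epoch, momentum.shape)  # 5 torch.Size([1, 3])
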
######################################################################
# Saving Multiple Models in One File
# ----------------------------------
#
# Save:
# ^^^^^
#
# .. code:: python
#
#    torch.save({
#        'modelA_state_dict': modelA.state_dict(),
#        'modelB_state_dict': modelB.state_dict(),
#        'optimizerA_state_dict': optimizerA.state_dict(),
#        'optimizerB_state_dict': optimizerB.state_dict(),
#        ...
#    }, PATH)
#
# Load:
# ^^^^^
#
# .. code:: python
#
#    modelA = TheModelAClass(*args, **kwargs)
#    modelB = TheModelBClass(*args, **kwargs)
#    optimizerA = TheOptimizerAClass(*args, **kwargs)
#    optimizerB = TheOptimizerBClass(*args, **kwargs)
#
#    checkpoint = torch.load(PATH, weights_only=True)
#    modelA.load_state_dict(checkpoint['modelA_state_dict'])
#    modelB.load_state_dict(checkpoint['modelB_state_dict'])
#    optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
#    optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
#
#    modelA.eval()
#    modelB.eval()
#    # - or -
#    modelA.train()
#    modelB.train()
#
# When saving a model composed of multiple ``torch.nn.Module`` instances,
# such as a GAN, a sequence-to-sequence model, or an ensemble of models,
# you follow the same approach as when you are saving a general
# checkpoint. In other words, save a dictionary containing each model’s
# *state_dict* and the *state_dict* of its optimizer. As mentioned before,
# you can save any other items that may aid you in resuming training by
# simply appending them to the dictionary.
#
# A common PyTorch convention is to save these checkpoints using the
# ``.tar`` file extension.
#
# To load the models, first initialize the models and optimizers, then
# load the dictionary locally using ``torch.load()``. From here, you can
# easily access the saved items by simply querying the dictionary as you
# would expect.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results. If you
# wish to resume training, call ``model.train()`` to set these layers to
# training mode.
#


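######################################################################
# The minimal sketch below shows this pattern. It assumes two
# hypothetical ``nn.Linear`` models standing in for ``modelA`` and
# ``modelB``, SGD optimizers, and an in-memory buffer in place of
# ``PATH``. Both models and both optimizers are stored in a single
# dictionary and restored by key.
#

import io

import torch
import torch.nn as nn
import torch.optim as optim

modelA, modelB = nn.Linear(4, 2), nn.Linear(2, 1)
optimizerA = optim.SGD(modelA.parameters(), lr=0.01)
optimizerB = optim.SGD(modelB.parameters(), lr=0.05)

buffer = io.BytesIO()
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
}, buffer)
buffer.seek(0)

# Re-create every component, then restore each one from its own key.
newA, newB = nn.Linear(4, 2), nn.Linear(2, 1)
new_optA = optim.SGD(newA.parameters(), lr=0.01)
new_optB = optim.SGD(newB.parameters(), lr=0.05)

checkpoint = torch.load(buffer, weights_only=True)
newA.load_state_dict(checkpoint['modelA_state_dict'])
newB.load_state_dict(checkpoint['modelB_state_dict'])
new_optA.load_state_dict(checkpoint['optimizerA_state_dict'])
new_optB.load_state_dict(checkpoint['optimizerB_state_dict'])

print(sorted(checkpoint.keys()))
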
######################################################################
# Warmstarting Model Using Parameters from a Different Model
# ----------------------------------------------------------
#
# Save:
# ^^^^^
#
# .. code:: python
#
#    torch.save(modelA.state_dict(), PATH)
#
# Load:
# ^^^^^
#
# .. code:: python
#
#    modelB = TheModelBClass(*args, **kwargs)
#    modelB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)
#
# Partially loading a model or loading a partial model are common
# scenarios in transfer learning or when training a new complex model.
# Leveraging trained parameters, even if only a few are usable, will help
# to warmstart the training process and hopefully help your model converge
# much faster than training from scratch.
#
# Whether you are loading from a partial *state_dict*, which is missing
# some keys, or loading a *state_dict* with more keys than the model that
# you are loading into, you can set the ``strict`` argument to **False**
# in the ``load_state_dict()`` function to ignore non-matching keys. Note
# that ``strict=False`` only relaxes key matching. A key present in both
# dictionaries whose tensors have different shapes still raises an error.
#
# If you want to load parameters from one layer to another, but some keys
# do not match, simply change the name of the parameter keys in the
# *state_dict* that you are loading to match the keys in the model that
# you are loading into.
#


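######################################################################
# The minimal sketch below illustrates both techniques. The two models are
# hypothetical. ``ModelB`` reuses the ``features`` layer of ``ModelA`` but
# names its output layer differently. ``load_state_dict`` returns the
# missing and unexpected keys, which is a convenient way to check what was
# actually transferred. Renaming the keys then lets the output-layer
# weights be reused as well.
#

import torch
import torch.nn as nn


class ModelA(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 4)
        self.classifier = nn.Linear(4, 2)


class ModelB(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 4)  # same name and shape as in ModelA
        self.head = nn.Linear(4, 2)      # same shape, different name


state_a = ModelA().state_dict()
model_b = ModelB()

# strict=False skips non-matching keys instead of raising an error.
result = model_b.load_state_dict(state_a, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias']
print(result.unexpected_keys)  # ['classifier.weight', 'classifier.bias']

# Rename the keys so the classifier weights load into ``head`` as well.
renamed = {k.replace('classifier.', 'head.', 1): v for k, v in state_a.items()}
result = model_b.load_state_dict(renamed)  # strict=True: every key now matches
print(torch.equal(model_b.head.weight, state_a['classifier.weight']))  # True
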
######################################################################
# Saving & Loading Model Across Devices
# -------------------------------------
#
# Save on GPU, Load on CPU
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#    device = torch.device('cpu')
#    model = TheModelClass(*args, **kwargs)
#    model.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))
#
# When loading a model on a CPU that was trained with a GPU, pass
# ``torch.device('cpu')`` to the ``map_location`` argument in the
# ``torch.load()`` function. In this case, the storages underlying the
# tensors are dynamically remapped to the CPU device using the
# ``map_location`` argument.
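#
# The effect of ``map_location`` can be checked with the minimal sketch
# below, whether or not a GPU is present. It saves a hypothetical
# ``nn.Linear`` model from whichever device is available (CUDA if present,
# otherwise CPU) into an in-memory buffer that stands in for ``PATH``. It
# then confirms that every loaded tensor is on the CPU.
#

import io

import torch
import torch.nn as nn

save_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_model = nn.Linear(4, 2).to(save_device)

buffer = io.BytesIO()
torch.save(gpu_model.state_dict(), buffer)
buffer.seek(0)

# Remap every storage to the CPU while deserializing.
state = torch.load(buffer, map_location=torch.device('cpu'), weights_only=True)
cpu_model = nn.Linear(4, 2)
cpu_model.load_state_dict(state)

print({name: t.device.type for name, t in state.items()})
# {'weight': 'cpu', 'bias': 'cpu'}

######################################################################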
#
# Save on GPU, Load on GPU
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#    device = torch.device("cuda")
#    model = TheModelClass(*args, **kwargs)
#    model.load_state_dict(torch.load(PATH, weights_only=True))
#    model.to(device)
#    # Make sure to call input = input.to(device) on any input tensors that you feed to the model
#
# When loading a model on a GPU that was trained and saved on GPU, simply
# move the initialized ``model`` to the GPU using
# ``model.to(torch.device('cuda'))``. Also, be sure to use the
# ``.to(torch.device('cuda'))`` function on all model inputs to prepare
# the data for the model. Note that calling ``my_tensor.to(device)``
# returns a new copy of ``my_tensor`` on GPU. It does NOT overwrite
# ``my_tensor``. Therefore, remember to manually overwrite tensors:
# ``my_tensor = my_tensor.to(torch.device('cuda'))``.
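#
# The minimal sketch below contrasts how ``Module.to()`` and
# ``Tensor.to()`` behave. ``Module.to()`` moves the module's parameters in
# place and returns the same object. ``Tensor.to()`` leaves the original
# tensor untouched, so its result must be reassigned. The sketch uses a
# hypothetical ``nn.Linear`` model. It falls back to the CPU when CUDA is
# absent, in which case no data actually moves.
#

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(4, 2)
moved_model = model.to(device)
print(moved_model is model)  # True: modules are moved in place

my_tensor = torch.randn(3, 4)
my_tensor.to(device)           # result discarded: my_tensor itself is not modified
print(my_tensor.device.type)   # cpu
my_tensor = my_tensor.to(device)  # reassign so later code uses the moved tensor

with torch.no_grad():
    output = model(my_tensor)
print(output.device == model.weight.device)  # True

######################################################################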
#
# Save on CPU, Load on GPU
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#    device = torch.device("cuda")
#    model = TheModelClass(*args, **kwargs)
#    model.load_state_dict(torch.load(PATH, weights_only=True, map_location="cuda:0"))  # Choose whatever GPU device number you want
#    model.to(device)
#    # Make sure to call input = input.to(device) on any input tensors that you feed to the model
#
# When loading a model on a GPU that was trained and saved on CPU, set the
# ``map_location`` argument in the ``torch.load()`` function to
# ``cuda:device_id``. This loads the saved tensors onto the given GPU
# device. Next, be sure to call ``model.to(torch.device('cuda'))`` to move
# the model’s parameter tensors to the GPU. ``load_state_dict()`` only
# copies values into the model's existing tensors, so it does not move
# them by itself. Finally, be sure to use the
# ``.to(torch.device('cuda'))`` function on all model inputs to prepare
# the data for the model on the GPU. Note that calling
# ``my_tensor.to(device)`` returns a new copy of ``my_tensor`` on GPU. It
# does NOT overwrite ``my_tensor``. Therefore, remember to manually
# overwrite tensors: ``my_tensor = my_tensor.to(torch.device('cuda'))``.
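#
# A device-agnostic version of this recipe is sketched below. It assumes
# a hypothetical ``nn.Linear`` model that is saved on the CPU into an
# in-memory buffer standing in for ``PATH``. ``map_location`` targets
# ``cuda:0`` when a GPU is available and falls back to the CPU otherwise.
#

import io

import torch
import torch.nn as nn

cpu_model = nn.Linear(4, 2)  # trained and saved on the CPU
buffer = io.BytesIO()
torch.save(cpu_model.state_dict(), buffer)
buffer.seek(0)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(4, 2)
model.load_state_dict(torch.load(buffer, map_location=device, weights_only=True))
# load_state_dict copied values into the model's existing (CPU) parameters,
# so the model itself still has to be moved.
model.to(device)

inputs = torch.randn(3, 4).to(device)
with torch.no_grad():
    outputs = model(inputs)
print(outputs.device == device)  # True

######################################################################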
#
# Saving ``torch.nn.DataParallel`` Models
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model.module.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#    # Load to whatever device you want
#    model = TheModelClass(*args, **kwargs)
#    model.load_state_dict(torch.load(PATH, weights_only=True))
#
# ``torch.nn.DataParallel`` is a model wrapper that enables parallel GPU
# utilization. To save a ``DataParallel`` model generically, save the
# ``model.module.state_dict()``. The wrapper's own *state_dict* prefixes
# every key with ``module.``, which an unwrapped model would not
# recognize. This way, you have the flexibility to load the model any way
# you want to any device you want.
#
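
######################################################################
# The ``module.`` prefix can be observed directly in the minimal sketch
# below, which assumes a hypothetical ``nn.Linear`` model.
# ``nn.DataParallel`` can be constructed without GPUs, in which case it
# simply calls the wrapped module, so the sketch also runs on a CPU-only
# machine. The last step shows one way to strip the prefix from a
# checkpoint that was saved from the wrapper by mistake.
#

import torch
import torch.nn as nn

model = nn.DataParallel(nn.Linear(4, 2))

print(list(model.state_dict().keys()))         # ['module.weight', 'module.bias']
print(list(model.module.state_dict().keys()))  # ['weight', 'bias']

# Hypothetical recovery step for a wrapper-level state_dict.
wrapped_state = model.state_dict()
plain_state = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in wrapped_state.items()
}
plain_model = nn.Linear(4, 2)
plain_model.load_state_dict(plain_state)
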