Path: blob/main/beginner_source/basics/saveloadrun_tutorial.py
"""1`Learn the Basics <intro.html>`_ ||2`Quickstart <quickstart_tutorial.html>`_ ||3`Tensors <tensorqs_tutorial.html>`_ ||4`Datasets & DataLoaders <data_tutorial.html>`_ ||5`Transforms <transforms_tutorial.html>`_ ||6`Build Model <buildmodel_tutorial.html>`_ ||7`Autograd <autogradqs_tutorial.html>`_ ||8`Optimization <optimization_tutorial.html>`_ ||9**Save & Load Model**1011Save and Load the Model12============================1314In this section we will look at how to persist model state with saving, loading and running model predictions.15"""1617import torch18import torchvision.models as models192021#######################################################################22# Saving and Loading Model Weights23# --------------------------------24# PyTorch models store the learned parameters in an internal25# state dictionary, called ``state_dict``. These can be persisted via the ``torch.save``26# method:2728model = models.vgg16(weights='IMAGENET1K_V1')29torch.save(model.state_dict(), 'model_weights.pth')3031##########################32# To load model weights, you need to create an instance of the same model first, and then load the parameters33# using ``load_state_dict()`` method.34#35# In the code below, we set ``weights_only=True`` to limit the36# functions executed during unpickling to only those necessary for37# loading weights. Using ``weights_only=True`` is considered38# a best practice when loading weights.3940model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model41model.load_state_dict(torch.load('model_weights.pth', weights_only=True))42model.eval()4344###########################45# .. note:: be sure to call ``model.eval()`` method before inferencing to set the dropout and batch normalization layers to evaluation mode. 
Failing to do this will yield inconsistent inference results.4647#######################################################################48# Saving and Loading Models with Shapes49# -------------------------------------50# When loading model weights, we needed to instantiate the model class first, because the class51# defines the structure of a network. We might want to save the structure of this class together with52# the model, in which case we can pass ``model`` (and not ``model.state_dict()``) to the saving function:5354torch.save(model, 'model.pth')5556########################57# We can then load the model as demonstrated below.58#59# As described in `Saving and loading torch.nn.Modules <https://pytorch.org/docs/main/notes/serialization.html#saving-and-loading-torch-nn-modules>`_,60# saving ``state_dict`` is considered the best practice. However,61# below we use ``weights_only=False`` because this involves loading the62# model, which is a legacy use case for ``torch.save``.6364model = torch.load('model.pth', weights_only=False),6566########################67# .. note:: This approach uses Python `pickle <https://docs.python.org/3/library/pickle.html>`_ module when serializing the model, thus it relies on the actual class definition to be available when loading the model.6869#######################70# Related Tutorials71# -----------------72# - `Saving and Loading a General Checkpoint in PyTorch <https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html>`_73# - `Tips for loading an nn.Module from a checkpoint <https://pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html?highlight=loading%20nn%20module%20from%20checkpoint>`_747576
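
#######################################################################
# Example: Round-Tripping a ``state_dict``
# ----------------------------------------
# The weight-saving workflow described earlier can be sketched end to end
# without downloading VGG16. This is a minimal sketch under stated
# assumptions, not part of the original tutorial: ``TinyNet`` is a
# hypothetical stand-in model, and an in-memory ``io.BytesIO`` buffer is
# assumed in place of the ``model_weights.pth`` file. It also compares
# outputs in evaluation and training mode to illustrate why ``model.eval()``
# matters when the model contains dropout.

```python
import io

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a larger model such as VGG16."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(self.dropout(torch.relu(self.fc1(x))))


torch.manual_seed(0)
source = TinyNet()

# Serialize only the learned parameters (in memory instead of a .pth file).
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)
buffer.seek(0)

# Re-create the architecture first, then load the saved parameters into it.
restored = TinyNet()
restored.load_state_dict(torch.load(buffer, weights_only=True))

# Switch both models to evaluation mode so that dropout is disabled.
source.eval()
restored.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same_in_eval = torch.equal(source(x), restored(x))

# In training mode dropout stays active, so repeated calls generally differ.
restored.train()
with torch.no_grad():
    differs_in_train = not torch.equal(restored(x), restored(x))

print("eval outputs match:", same_in_eval)
print("train outputs differ between calls:", differs_in_train)
```

# Only the parameters travel through the buffer, so the class definition
# (here ``TinyNet``) has to be available wherever the weights are loaded.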