# -*- coding: utf-8 -*-
"""
Saving and Loading Models
=========================
**Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_

This document provides solutions to a variety of use cases regarding the
saving and loading of PyTorch models. Feel free to read the whole
document, or just skip to the code you need for a desired use case.

When it comes to saving and loading models, there are three core
functions to be familiar with:

1) `torch.save <https://pytorch.org/docs/stable/torch.html?highlight=save#torch.save>`__:
   Saves a serialized object to disk. This function uses Python's
   `pickle <https://docs.python.org/3/library/pickle.html>`__ utility
   for serialization. Models, tensors, and dictionaries of all kinds of
   objects can be saved using this function.

2) `torch.load <https://pytorch.org/docs/stable/torch.html?highlight=torch%20load#torch.load>`__:
   Uses `pickle <https://docs.python.org/3/library/pickle.html>`__\ 's
   unpickling facilities to deserialize pickled object files to memory.
   This function also lets you specify the device to load the data onto
   (see `Saving & Loading Model Across
   Devices <#saving-loading-model-across-devices>`__).

3) `torch.nn.Module.load_state_dict <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict>`__:
   Loads a model's parameter dictionary using a deserialized
   *state_dict*. For more information on *state_dict*, see `What is a
   state_dict? <#what-is-a-state-dict>`__.


**Contents:**

- `What is a state_dict? <#what-is-a-state-dict>`__
- `Saving & Loading Model for
  Inference <#saving-loading-model-for-inference>`__
- `Saving & Loading a General
  Checkpoint <#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training>`__
- `Saving Multiple Models in One
  File <#saving-multiple-models-in-one-file>`__
- `Warmstarting Model Using Parameters from a Different
  Model <#warmstarting-model-using-parameters-from-a-different-model>`__
- `Saving & Loading Model Across
  Devices <#saving-loading-model-across-devices>`__

"""


######################################################################
# What is a ``state_dict``?
# -------------------------
#
# In PyTorch, the learnable parameters (i.e. weights and biases) of a
# ``torch.nn.Module`` model are contained in the model's *parameters*
# (accessed with ``model.parameters()``). A *state_dict* is simply a
# Python dictionary object that maps each layer to its parameter tensor.
# Note that only layers with learnable parameters (convolutional layers,
# linear layers, etc.) and registered buffers (e.g. batchnorm's
# ``running_mean``) have entries in the model's *state_dict*. Optimizer
# objects (``torch.optim``) also have a *state_dict*, which contains
# information about the optimizer's state, as well as the hyperparameters
# used.
#
# Because *state_dict* objects are Python dictionaries, they can be easily
# saved, updated, altered, and restored, adding a great deal of modularity
# to PyTorch models and optimizers.
#
# Example:
# ^^^^^^^^
#
# Let's take a look at the *state_dict* from the simple model used in the
# `Training a
# classifier <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py>`__
# tutorial.
#
# .. code:: python
#
#     import torch.nn as nn
#     import torch.nn.functional as F
#     import torch.optim as optim
#
#     # Define model
#     class TheModelClass(nn.Module):
#         def __init__(self):
#             super(TheModelClass, self).__init__()
#             self.conv1 = nn.Conv2d(3, 6, 5)
#             self.pool = nn.MaxPool2d(2, 2)
#             self.conv2 = nn.Conv2d(6, 16, 5)
#             self.fc1 = nn.Linear(16 * 5 * 5, 120)
#             self.fc2 = nn.Linear(120, 84)
#             self.fc3 = nn.Linear(84, 10)
#
#         def forward(self, x):
#             x = self.pool(F.relu(self.conv1(x)))
#             x = self.pool(F.relu(self.conv2(x)))
#             x = x.view(-1, 16 * 5 * 5)
#             x = F.relu(self.fc1(x))
#             x = F.relu(self.fc2(x))
#             x = self.fc3(x)
#             return x
#
#     # Initialize model
#     model = TheModelClass()
#
#     # Initialize optimizer
#     optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
#
#     # Print model's state_dict
#     print("Model's state_dict:")
#     for param_tensor in model.state_dict():
#         print(param_tensor, "\t", model.state_dict()[param_tensor].size())
#
#     # Print optimizer's state_dict
#     print("Optimizer's state_dict:")
#     for var_name in optimizer.state_dict():
#         print(var_name, "\t", optimizer.state_dict()[var_name])
#
# **Output:**
#
# .. code-block:: sh
#
#     Model's state_dict:
#     conv1.weight     torch.Size([6, 3, 5, 5])
#     conv1.bias       torch.Size([6])
#     conv2.weight     torch.Size([16, 6, 5, 5])
#     conv2.bias       torch.Size([16])
#     fc1.weight       torch.Size([120, 400])
#     fc1.bias         torch.Size([120])
#     fc2.weight       torch.Size([84, 120])
#     fc2.bias         torch.Size([84])
#     fc3.weight       torch.Size([10, 84])
#     fc3.bias         torch.Size([10])
#
#     Optimizer's state_dict:
#     state            {}
#     param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
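#
# Because a *state_dict* is an ordinary Python dictionary, you can also
# manipulate it directly. Below is a minimal sketch (reusing the ``model``
# defined above) that snapshots the parameters, alters one entry in
# place, and then restores the snapshot:
#
# .. code:: python
#
#     import copy
#
#     # Take an independent snapshot of the current parameters
#     snapshot = copy.deepcopy(model.state_dict())
#
#     # The returned tensors share storage with the model, so in-place
#     # edits change the live parameters (here: zero out conv1's bias)
#     model.state_dict()['conv1.bias'].zero_()
#
#     # Restore the original parameters from the snapshot
#     model.load_state_dict(snapshot)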


######################################################################
# Saving & Loading Model for Inference
# ------------------------------------
#
# Save/Load ``state_dict`` (Recommended)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#    model = TheModelClass(*args, **kwargs)
#    model.load_state_dict(torch.load(PATH, weights_only=True))
#    model.eval()
#
# .. note::
#    The 1.6 release of PyTorch switched ``torch.save`` to use a new
#    zip file-based format. ``torch.load`` still retains the ability to
#    load files in the old format. If for any reason you want ``torch.save``
#    to use the old format, pass the keyword argument
#    ``_use_new_zipfile_serialization=False``.
#
# When saving a model for inference, it is only necessary to save the
# trained model's learned parameters. Saving the model's *state_dict* with
# the ``torch.save()`` function will give you the most flexibility for
# restoring the model later, which is why it is the recommended method for
# saving models.
#
# A common PyTorch convention is to save models using either a ``.pt`` or
# ``.pth`` file extension.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results.
#
# .. note::
#
#    Notice that the ``load_state_dict()`` function takes a dictionary
#    object, NOT a path to a saved object. This means that you must
#    deserialize the saved *state_dict* before you pass it to the
#    ``load_state_dict()`` function. For example, you CANNOT load using
#    ``model.load_state_dict(PATH)``.
#
# .. note::
#
#    If you only plan to keep the best performing model (according to the
#    validation loss), don't forget that ``best_model_state = model.state_dict()``
#    returns a reference to the state and not its copy! You must either
#    serialize ``best_model_state`` immediately or use
#    ``best_model_state = deepcopy(model.state_dict())``; otherwise your
#    ``best_model_state`` will keep getting updated by subsequent training
#    iterations, and the final saved state will be that of the overfitted
#    model.
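#
# A minimal sketch of the pattern described in the note above, assuming
# hypothetical ``train_one_epoch`` and ``validate`` helpers (the latter
# returning the validation loss):
#
# .. code:: python
#
#    from copy import deepcopy
#
#    best_loss = float('inf')
#    best_model_state = None
#
#    for epoch in range(num_epochs):
#        train_one_epoch(model, optimizer)   # hypothetical helper
#        val_loss = validate(model)          # hypothetical helper
#        if val_loss < best_loss:
#            best_loss = val_loss
#            # deepcopy, so later training steps don't mutate the snapshot
#            best_model_state = deepcopy(model.state_dict())
#
#    torch.save(best_model_state, PATH)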
#
# Save/Load Entire Model
# ^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model, PATH)
#
# **Load:**
#
# .. code:: python
#
#    # Model class must be defined somewhere
#    model = torch.load(PATH, weights_only=False)
#    model.eval()
#
# This save/load process uses the most intuitive syntax and involves the
# least amount of code. Saving a model in this way will save the entire
# module using Python's
# `pickle <https://docs.python.org/3/library/pickle.html>`__ module. The
# disadvantage of this approach is that the serialized data is bound to
# the specific classes and the exact directory structure used when the
# model is saved. The reason for this is that pickle does not save the
# model class itself. Rather, it saves a path to the file containing the
# class, which is used during load time. Because of this, your code can
# break in various ways when used in other projects or after refactors.
#
# A common PyTorch convention is to save models using either a ``.pt`` or
# ``.pth`` file extension.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results.
#
# Export/Load Model in TorchScript Format
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# One common way to do inference with a trained model is to use
# `TorchScript <https://pytorch.org/docs/stable/jit.html>`__, an intermediate
# representation of a PyTorch model that can be run in Python as well as in a
# high performance environment like C++. TorchScript is the recommended
# model format for scaled inference and deployment.
#
# .. note::
#    Using the TorchScript format, you will be able to load the exported model and
#    run inference without defining the model class.
#
# **Export:**
#
# .. code:: python
#
#    model_scripted = torch.jit.script(model)   # Export to TorchScript
#    model_scripted.save('model_scripted.pt')   # Save
#
# **Load:**
#
# .. code:: python
#
#    model = torch.jit.load('model_scripted.pt')
#    model.eval()
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results.
#
# For more information on TorchScript, feel free to visit the dedicated
# `tutorials <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`__.
# You will get familiar with the tracing conversion and learn how to
# run a TorchScript module in a `C++ environment <https://pytorch.org/tutorials/advanced/cpp_export.html>`__.


######################################################################
# Saving & Loading a General Checkpoint for Inference and/or Resuming Training
# ----------------------------------------------------------------------------
#
# Save:
# ^^^^^
#
# .. code:: python
#
#    torch.save({
#                'epoch': epoch,
#                'model_state_dict': model.state_dict(),
#                'optimizer_state_dict': optimizer.state_dict(),
#                'loss': loss,
#                ...
#                }, PATH)
#
# Load:
# ^^^^^
#
# .. code:: python
#
#    model = TheModelClass(*args, **kwargs)
#    optimizer = TheOptimizerClass(*args, **kwargs)
#
#    checkpoint = torch.load(PATH, weights_only=True)
#    model.load_state_dict(checkpoint['model_state_dict'])
#    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#    epoch = checkpoint['epoch']
#    loss = checkpoint['loss']
#
#    model.eval()
#    # - or -
#    model.train()
#
# When saving a general checkpoint, to be used for either inference or
# resuming training, you must save more than just the model's
# *state_dict*. It is important to also save the optimizer's *state_dict*,
# as this contains buffers and parameters that are updated as the model
# trains. Other items that you may want to save are the epoch you left off
# on, the latest recorded training loss, external ``torch.nn.Embedding``
# layers, and so on. As a result, such a checkpoint is often 2 to 3 times
# larger than the model alone.
#
# To save multiple components, organize them in a dictionary and use
# ``torch.save()`` to serialize the dictionary. A common PyTorch
# convention is to save these checkpoints using the ``.tar`` file
# extension.
#
# To load the items, first initialize the model and optimizer, then load
# the dictionary locally using ``torch.load()``. From here, you can easily
# access the saved items by simply querying the dictionary as you would
# expect.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results. If you
# wish to resume training, call ``model.train()`` to ensure these layers
# are in training mode.
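#
# For example, a checkpoint in the format shown above can be written at
# the end of every epoch. A minimal sketch, assuming a hypothetical
# ``train_one_epoch`` helper that returns the latest training loss:
#
# .. code:: python
#
#    for epoch in range(start_epoch, num_epochs):
#        loss = train_one_epoch(model, optimizer)   # hypothetical helper
#        torch.save({
#                    'epoch': epoch,
#                    'model_state_dict': model.state_dict(),
#                    'optimizer_state_dict': optimizer.state_dict(),
#                    'loss': loss,
#                    }, 'checkpoint.tar')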


######################################################################
# Saving Multiple Models in One File
# ----------------------------------
#
# Save:
# ^^^^^
#
# .. code:: python
#
#    torch.save({
#                'modelA_state_dict': modelA.state_dict(),
#                'modelB_state_dict': modelB.state_dict(),
#                'optimizerA_state_dict': optimizerA.state_dict(),
#                'optimizerB_state_dict': optimizerB.state_dict(),
#                ...
#                }, PATH)
#
# Load:
# ^^^^^
#
# .. code:: python
#
#    modelA = TheModelAClass(*args, **kwargs)
#    modelB = TheModelBClass(*args, **kwargs)
#    optimizerA = TheOptimizerAClass(*args, **kwargs)
#    optimizerB = TheOptimizerBClass(*args, **kwargs)
#
#    checkpoint = torch.load(PATH, weights_only=True)
#    modelA.load_state_dict(checkpoint['modelA_state_dict'])
#    modelB.load_state_dict(checkpoint['modelB_state_dict'])
#    optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
#    optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
#
#    modelA.eval()
#    modelB.eval()
#    # - or -
#    modelA.train()
#    modelB.train()
#
# When saving a model composed of multiple ``torch.nn.Modules``, such as
# a GAN, a sequence-to-sequence model, or an ensemble of models, you
# follow the same approach as when you are saving a general checkpoint. In
# other words, save a dictionary of each model's *state_dict* and
# corresponding optimizer. As mentioned before, you can save any other
# items that may aid you in resuming training by simply appending them to
# the dictionary.
#
# A common PyTorch convention is to save these checkpoints using the
# ``.tar`` file extension.
#
# To load the models, first initialize the models and optimizers, then
# load the dictionary locally using ``torch.load()``. From here, you can
# easily access the saved items by simply querying the dictionary as you
# would expect.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results. If you
# wish to resume training, call ``model.train()`` to set these layers to
# training mode.


######################################################################
# Warmstarting Model Using Parameters from a Different Model
# ----------------------------------------------------------
#
# Save:
# ^^^^^
#
# .. code:: python
#
#    torch.save(modelA.state_dict(), PATH)
#
# Load:
# ^^^^^
#
# .. code:: python
#
#    modelB = TheModelBClass(*args, **kwargs)
#    modelB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)
#
# Partially loading a model, or loading a partial model, is a common
# scenario when transfer learning or training a new, complex model.
# Leveraging trained parameters, even if only a few are usable, will help
# to warmstart the training process and hopefully help your model converge
# much faster than training from scratch.
#
# Whether you are loading from a partial *state_dict*, which is missing
# some keys, or loading a *state_dict* with more keys than the model that
# you are loading into, you can set the ``strict`` argument to ``False``
# in the ``load_state_dict()`` function to ignore non-matching keys.
#
# If you want to load parameters from one layer to another, but some keys
# do not match, simply change the names of the parameter keys in the
# *state_dict* that you are loading to match the keys in the model that
# you are loading into, as in the sketch below.
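#
# A minimal sketch of such a renaming, assuming the checkpoint stores a
# layer under ``conv1`` while ``modelB`` expects ``features.conv1`` (both
# names are purely illustrative):
#
# .. code:: python
#
#    state_dict = torch.load(PATH, weights_only=True)
#
#    # Rename keys to match the destination model's naming
#    state_dict['features.conv1.weight'] = state_dict.pop('conv1.weight')
#    state_dict['features.conv1.bias'] = state_dict.pop('conv1.bias')
#
#    modelB.load_state_dict(state_dict, strict=False)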
code:: python350#351# modelA = TheModelAClass(*args, **kwargs)352# modelB = TheModelBClass(*args, **kwargs)353# optimizerA = TheOptimizerAClass(*args, **kwargs)354# optimizerB = TheOptimizerBClass(*args, **kwargs)355#356# checkpoint = torch.load(PATH, weights_only=True)357# modelA.load_state_dict(checkpoint['modelA_state_dict'])358# modelB.load_state_dict(checkpoint['modelB_state_dict'])359# optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])360# optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])361#362# modelA.eval()363# modelB.eval()364# # - or -365# modelA.train()366# modelB.train()367#368# When saving a model comprised of multiple ``torch.nn.Modules``, such as369# a GAN, a sequence-to-sequence model, or an ensemble of models, you370# follow the same approach as when you are saving a general checkpoint. In371# other words, save a dictionary of each model’s *state_dict* and372# corresponding optimizer. As mentioned before, you can save any other373# items that may aid you in resuming training by simply appending them to374# the dictionary.375#376# A common PyTorch convention is to save these checkpoints using the377# ``.tar`` file extension.378#379# To load the models, first initialize the models and optimizers, then380# load the dictionary locally using ``torch.load()``. From here, you can381# easily access the saved items by simply querying the dictionary as you382# would expect.383#384# Remember that you must call ``model.eval()`` to set dropout and batch385# normalization layers to evaluation mode before running inference.386# Failing to do this will yield inconsistent inference results. If you387# wish to resuming training, call ``model.train()`` to set these layers to388# training mode.389#390391392######################################################################393# Warmstarting Model Using Parameters from a Different Model394# ----------------------------------------------------------395#396# Save:397# ^^^^^398#399# .. code:: python400#401# torch.save(modelA.state_dict(), PATH)402#403# Load:404# ^^^^^405#406# .. code:: python407#408# modelB = TheModelBClass(*args, **kwargs)409# modelB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)410#411# Partially loading a model or loading a partial model are common412# scenarios when transfer learning or training a new complex model.413# Leveraging trained parameters, even if only a few are usable, will help414# to warmstart the training process and hopefully help your model converge415# much faster than training from scratch.416#417# Whether you are loading from a partial *state_dict*, which is missing418# some keys, or loading a *state_dict* with more keys than the model that419# you are loading into, you can set the ``strict`` argument to **False**420# in the ``load_state_dict()`` function to ignore non-matching keys.421#422# If you want to load parameters from one layer to another, but some keys423# do not match, simply change the name of the parameter keys in the424# *state_dict* that you are loading to match the keys in the model that425# you are loading into.426#427428429######################################################################430# Saving & Loading Model Across Devices431# -------------------------------------432#433# Save on GPU, Load on CPU434# ^^^^^^^^^^^^^^^^^^^^^^^^435#436# **Save:**437#438# .. code:: python439#440# torch.save(model.state_dict(), PATH)441#442# **Load:**443#444# .. 
code:: python445#446# device = torch.device('cpu')447# model = TheModelClass(*args, **kwargs)448# model.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))449#450# When loading a model on a CPU that was trained with a GPU, pass451# ``torch.device('cpu')`` to the ``map_location`` argument in the452# ``torch.load()`` function. In this case, the storages underlying the453# tensors are dynamically remapped to the CPU device using the454# ``map_location`` argument.455#456# Save on GPU, Load on GPU457# ^^^^^^^^^^^^^^^^^^^^^^^^458#459# **Save:**460#461# .. code:: python462#463# torch.save(model.state_dict(), PATH)464#465# **Load:**466#467# .. code:: python468#469# device = torch.device("cuda")470# model = TheModelClass(*args, **kwargs)471# model.load_state_dict(torch.load(PATH, weights_only=True))472# model.to(device)473# # Make sure to call input = input.to(device) on any input tensors that you feed to the model474#475# When loading a model on a GPU that was trained and saved on GPU, simply476# convert the initialized ``model`` to a CUDA optimized model using477# ``model.to(torch.device('cuda'))``. Also, be sure to use the478# ``.to(torch.device('cuda'))`` function on all model inputs to prepare479# the data for the model. Note that calling ``my_tensor.to(device)``480# returns a new copy of ``my_tensor`` on GPU. It does NOT overwrite481# ``my_tensor``. Therefore, remember to manually overwrite tensors:482# ``my_tensor = my_tensor.to(torch.device('cuda'))``.483#484# Save on CPU, Load on GPU485# ^^^^^^^^^^^^^^^^^^^^^^^^486#487# **Save:**488#489# .. code:: python490#491# torch.save(model.state_dict(), PATH)492#493# **Load:**494#495# .. code:: python496#497# device = torch.device("cuda")498# model = TheModelClass(*args, **kwargs)499# model.load_state_dict(torch.load(PATH, weights_only=True, map_location="cuda:0")) # Choose whatever GPU device number you want500# model.to(device)501# # Make sure to call input = input.to(device) on any input tensors that you feed to the model502#503# When loading a model on a GPU that was trained and saved on CPU, set the504# ``map_location`` argument in the ``torch.load()`` function to505# ``cuda:device_id``. This loads the model to a given GPU device. Next, be506# sure to call ``model.to(torch.device('cuda'))`` to convert the model’s507# parameter tensors to CUDA tensors. Finally, be sure to use the508# ``.to(torch.device('cuda'))`` function on all model inputs to prepare509# the data for the CUDA optimized model. Note that calling510# ``my_tensor.to(device)`` returns a new copy of ``my_tensor`` on GPU. It511# does NOT overwrite ``my_tensor``. Therefore, remember to manually512# overwrite tensors: ``my_tensor = my_tensor.to(torch.device('cuda'))``.513#514# Saving ``torch.nn.DataParallel`` Models515# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^516#517# **Save:**518#519# .. code:: python520#521# torch.save(model.module.state_dict(), PATH)522#523# **Load:**524#525# .. code:: python526#527# # Load to whatever device you want528#529# ``torch.nn.DataParallel`` is a model wrapper that enables parallel GPU530# utilization. To save a ``DataParallel`` model generically, save the531# ``model.module.state_dict()``. This way, you have the flexibility to532# load the model any way you want to any device you want.533#534535536