.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/basics/saveloadrun_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_beginner_basics_saveloadrun_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_basics_saveloadrun_tutorial.py:


`Learn the Basics <intro.html>`_ ||
`Quickstart <quickstart_tutorial.html>`_ ||
`Tensors <tensorqs_tutorial.html>`_ ||
`Datasets & DataLoaders <data_tutorial.html>`_ ||
`Transforms <transforms_tutorial.html>`_ ||
`Build Model <buildmodel_tutorial.html>`_ ||
`Autograd <autogradqs_tutorial.html>`_ ||
`Optimization <optimization_tutorial.html>`_ ||
**Save & Load Model**

Save and Load the Model
============================

In this section we will look at how to persist model state by saving and loading it, and how to use the loaded model to run predictions.

.. GENERATED FROM PYTHON SOURCE LINES 17-22

.. code-block:: default


    import torch
    import torchvision.models as models



.. GENERATED FROM PYTHON SOURCE LINES 23-28

Saving and Loading Model Weights
--------------------------------
PyTorch models store the learned parameters in an internal
state dictionary called ``state_dict``. These parameters can be persisted with the ``torch.save``
function:

.. GENERATED FROM PYTHON SOURCE LINES 28-32

.. code-block:: default


    model = models.vgg16(weights='IMAGENET1K_V1')
    torch.save(model.state_dict(), 'model_weights.pth')

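To see what a ``state_dict`` actually holds, the sketch below builds a small toy model instead of downloading VGG16. ``TinyNet`` is hypothetical and exists only for illustration. Its ``state_dict`` is an ordered dictionary that maps each parameter or buffer name to its tensor, which is why it can be handed directly to ``torch.save``.

.. code-block:: python

    import torch
    import torch.nn as nn


    class TinyNet(nn.Module):
        """Hypothetical toy model used only to illustrate ``state_dict``."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(4, 8)
            self.bn = nn.BatchNorm1d(8)
            self.fc2 = nn.Linear(8, 2)

        def forward(self, x):
            return self.fc2(torch.relu(self.bn(self.fc1(x))))


    model = TinyNet()
    # Each entry maps a parameter or buffer name to its tensor.
    for name, tensor in model.state_dict().items():
        print(name, tuple(tensor.shape))

Besides learnable weights and biases, the batch normalization layer contributes its running statistics (``bn.running_mean``, ``bn.running_var`` and ``bn.num_batches_tracked``). These are buffers rather than parameters, but they are still part of the ``state_dict`` and are therefore saved as well.
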

.. GENERATED FROM PYTHON SOURCE LINES 33-35

To load model weights, you need to create an instance of the same model first, and then load the parameters
using the ``load_state_dict()`` method.

.. GENERATED FROM PYTHON SOURCE LINES 35-40

.. code-block:: default


    model = models.vgg16()  # no weights requested, so the parameters start randomly initialized
    model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
    model.eval()

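One way to confirm that the weights survived the round trip is to compare every tensor. The sketch below assumes a hypothetical ``make_model()`` helper that returns a small ``nn.Sequential`` in place of VGG16, and it saves to an in-memory ``io.BytesIO`` buffer rather than a file so that nothing is written to disk. Passing ``map_location='cpu'`` lets the load succeed even if the weights were saved from a GPU.

.. code-block:: python

    import io

    import torch
    import torch.nn as nn


    def make_model():
        # Hypothetical stand-in for models.vgg16(): same architecture on every call.
        return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


    trained = make_model()

    # Save only the learned parameters, here to an in-memory buffer.
    buffer = io.BytesIO()
    torch.save(trained.state_dict(), buffer)
    buffer.seek(0)

    # Create a fresh (randomly initialized) instance and load the saved parameters.
    restored = make_model()
    state = torch.load(buffer, map_location="cpu", weights_only=True)
    result = restored.load_state_dict(state)
    restored.eval()

    # With strict=True (the default), mismatched keys raise an error;
    # the returned named tuple lists them and is empty on success.
    print(result.missing_keys, result.unexpected_keys)

    same = all(
        torch.equal(a, b)
        for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
    )
    print("weights identical:", same)

Because the architecture must match exactly, loading into a model whose layers differ in name or shape fails under the default ``strict=True``, which is usually the behavior you want.
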

.. GENERATED FROM PYTHON SOURCE LINES 41-42

.. note:: Be sure to call the ``model.eval()`` method before running inference to set the dropout and batch normalization layers to evaluation mode. Failing to do so will yield inconsistent inference results.

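The effect described in the note can be observed directly. The sketch below uses a hypothetical model containing an ``nn.Dropout`` layer. In training mode dropout randomly zeroes activations, so two forward passes on the same input almost always differ; after ``model.eval()`` dropout is disabled and the outputs are repeatable. Wrapping inference in ``torch.no_grad()`` is a separate, common step that skips gradient tracking; it does not change layer behavior by itself.

.. code-block:: python

    import torch
    import torch.nn as nn

    torch.manual_seed(0)

    # Hypothetical model with a dropout layer whose behavior depends on the mode.
    model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5), nn.Linear(16, 1))
    x = torch.randn(4, 16)

    model.train()  # dropout active: each forward pass draws a new random mask
    train_a, train_b = model(x), model(x)

    model.eval()  # dropout disabled: forward passes are deterministic
    with torch.no_grad():  # no autograd bookkeeping during inference
        eval_a, eval_b = model(x), model(x)

    print("train outputs equal:", torch.equal(train_a, train_b))
    print("eval outputs equal:", torch.equal(eval_a, eval_b))

Batch normalization behaves analogously: in evaluation mode it uses the stored running statistics instead of the statistics of the current batch.
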
.. GENERATED FROM PYTHON SOURCE LINES 44-49

Saving and Loading Models with Shapes
-------------------------------------
When loading model weights, we needed to instantiate the model class first, because the class
defines the structure of the network. If we want to save this structure together with the
weights, we can pass ``model`` (and not ``model.state_dict()``) to the saving function:

.. GENERATED FROM PYTHON SOURCE LINES 49-52

.. code-block:: default


    torch.save(model, 'model.pth')


.. GENERATED FROM PYTHON SOURCE LINES 53-54

We can then load the model like this:

.. GENERATED FROM PYTHON SOURCE LINES 54-57

.. code-block:: default


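    # weights_only=False is needed to unpickle a full model object
    # (torch.load defaults to weights_only=True since PyTorch 2.6).
    model = torch.load('model.pth', weights_only=False)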

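The sketch below round-trips a whole module object through an in-memory buffer, using a hypothetical small ``nn.Sequential`` in place of VGG16. Because the saved data is a pickle of the module rather than a plain dictionary of tensors, ``torch.load`` needs ``weights_only=False`` to restore it. Only load such files from sources you trust, since unpickling can execute arbitrary code.

.. code-block:: python

    import io

    import torch
    import torch.nn as nn

    # Hypothetical small model standing in for VGG16.
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    buffer = io.BytesIO()
    torch.save(model, buffer)  # pickles the whole module, not just its state_dict
    buffer.seek(0)

    # weights_only=False is required to unpickle a full nn.Module object.
    loaded = torch.load(buffer, weights_only=False)
    loaded.eval()

    x = torch.randn(3, 4)
    with torch.no_grad():
        outputs_match = torch.allclose(model(x), loaded(x))
    print(type(loaded).__name__, outputs_match)

This works here because ``nn.Sequential``, ``nn.Linear`` and ``nn.ReLU`` can all be imported from ``torch`` when the file is loaded; a custom model class would need to be importable in the same way.
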

.. GENERATED FROM PYTHON SOURCE LINES 58-59

.. note:: This approach uses the Python `pickle <https://docs.python.org/3/library/pickle.html>`_ module to serialize the model, so the actual class definition must be available when the model is loaded.

.. GENERATED FROM PYTHON SOURCE LINES 61-64

Related Tutorials
-----------------
`Saving and Loading a General Checkpoint in PyTorch <https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html>`_




.. _sphx_glr_download_beginner_basics_saveloadrun_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: saveloadrun_tutorial.py <saveloadrun_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: saveloadrun_tutorial.ipynb <saveloadrun_tutorial.ipynb>`

