.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_basics_saveloadrun_tutorial.py:


Save and Load the Model
=======================

In this section we will look at how to persist model state by saving it to disk, and how to load it back to run model predictions.

.. code-block:: default


    import torch
    import torchvision.models as models


Saving and Loading Model Weights
--------------------------------
PyTorch models store the learned parameters in an internal state dictionary, called ``state_dict``. These can be persisted via the ``torch.save`` function:

.. code-block:: default


    # ``weights=`` replaces the deprecated ``pretrained=True`` argument (torchvision >= 0.13)
    model = models.vgg16(weights='IMAGENET1K_V1')
    torch.save(model.state_dict(), 'model_weights.pth')


To load model weights, you need to create an instance of the same model first, and then load the parameters using the ``load_state_dict()`` method.

.. code-block:: default


    model = models.vgg16()  # we do not specify ``weights``, i.e. create an untrained model
    # weights_only=True restricts unpickling to tensors and basic containers (torch >= 1.13)
    model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
    model.eval()

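The same save/load round trip can be tried without downloading VGG16. The sketch below assumes a small, hypothetical ``nn.Sequential`` network and an in-memory ``io.BytesIO`` buffer in place of a ``.pth`` file; neither is part of this tutorial. It saves the ``state_dict``, loads it into a freshly constructed instance of the same architecture, and checks that both models give identical outputs once ``eval()`` has switched the ``Dropout`` layer to evaluation mode.

.. code-block:: python

    import io

    import torch
    from torch import nn


    def make_model():
        # Hypothetical stand-in for models.vgg16(): any fixed architecture works
        return nn.Sequential(
            nn.Linear(4, 8),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(8, 2),
        )


    torch.manual_seed(0)
    original = make_model()

    # Save only the learned parameters; the BytesIO buffer stands in for 'model_weights.pth'
    buffer = io.BytesIO()
    torch.save(original.state_dict(), buffer)
    buffer.seek(0)

    # Recreate the same architecture first, then copy the saved tensors into it
    restored = make_model()
    restored.load_state_dict(torch.load(buffer, weights_only=True))

    # Evaluation mode turns Dropout into an identity op, so outputs are deterministic
    original.eval()
    restored.eval()

    x = torch.randn(3, 4)
    with torch.no_grad():
        same = torch.equal(original(x), restored(x))
    print(same)  # True

Because ``load_state_dict`` only copies tensors into an existing module, the loading code has to construct the architecture itself; the ``state_dict`` carries no information about the layers beyond their parameter names and values.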
.. note:: Be sure to call the ``model.eval()`` method before running inference to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.

Saving and Loading Models with Shapes
-------------------------------------
When loading model weights, we needed to instantiate the model class first, because the class defines the structure of a network. We might want to save the structure of this class together with the model, in which case we can pass ``model`` (and not ``model.state_dict()``) to the saving function:

.. code-block:: default


    torch.save(model, 'model.pth')


We can then load the model like this:

.. code-block:: default


    # A full pickled model needs weights_only=False (the default became True in PyTorch 2.6);
    # unpickling can run arbitrary code, so only load files you trust
    model = torch.load('model.pth', weights_only=False)


.. note:: This approach uses the Python `pickle <https://docs.python.org/3/library/pickle.html>`_ module when serializing the model, so it relies on the actual class definition being available when loading the model.

Related Tutorials
-----------------
`Saving and Loading a General Checkpoint in PyTorch <https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html>`_