227 | 227 | # normalization layers to evaluation mode before running inference. |
228 | 228 | # Failing to do this will yield inconsistent inference results. |
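| | +#
| | +# For example (a minimal sketch; ``model`` and ``input_batch`` stand in for a
| | +# trained module containing dropout or batch normalization layers and its input):
| | +#
| | +# .. code-block:: python
| | +#
| | +#     model.eval()  # switch dropout and batchnorm layers to inference behavior
| | +#     with torch.no_grad():  # optionally skip gradient tracking during inference
| | +#         predictions = model(input_batch)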
229 | 229 | # |
230 | | -# Export/Load Model in TorchScript Format |
231 | | -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 230 | +# Saving an Exported Program |
| 231 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
232 | 232 | # |
233 | | -# One common way to do inference with a trained model is to use |
234 | | -# `TorchScript <https://pytorch.org/docs/stable/jit.html>`__, an intermediate |
235 | | -# representation of a PyTorch model that can be run in Python as well as in a |
236 | | -# high performance environment like C++. TorchScript is actually the recommended model format |
237 | | -# for scaled inference and deployment. |
| 233 | +# If you are using ``torch.export``, you can save and load your ``ExportedProgram`` using the |
| 234 | +# ``torch.export.save()`` and ``torch.export.load()`` APIs with the ``.pt2`` file extension:
238 | 235 | # |
239 | | -# .. note:: |
240 | | -# Using the TorchScript format, you will be able to load the exported model and |
241 | | -# run inference without defining the model class. |
242 | | -# |
243 | | -# **Export:** |
244 | | -# |
245 | | -# .. code:: python |
246 | | -# |
247 | | -#     model_scripted = torch.jit.script(model) # Export to TorchScript
248 | | -#     model_scripted.save('model_scripted.pt') # Save
249 | | -# |
250 | | -# **Load:** |
| 236 | +# .. code-block:: python |
| 237 | +# |
| 238 | +#     class SimpleModel(torch.nn.Module):
| 239 | +#         def forward(self, x):
| 240 | +#             return x + 10
251 | 241 | # |
252 | | -# .. code:: python |
| 242 | +#     # Create a sample input
| 243 | +#     sample_input = torch.randn(5)
| 244 | +#
| 245 | +#     # Export the model
| 246 | +#     exported_program = torch.export.export(SimpleModel(), (sample_input,))
253 | 247 | # |
254 | | -#     model = torch.jit.load('model_scripted.pt')
255 | | -#     model.eval()
| 248 | +#     # Save the exported program
| 249 | +#     torch.export.save(exported_program, 'exported_program.pt2')
256 | 250 | # |
257 | | -# Remember that you must call ``model.eval()`` to set dropout and batch |
258 | | -# normalization layers to evaluation mode before running inference. |
259 | | -# Failing to do this will yield inconsistent inference results. |
| 251 | +#     # Load the exported program
| 252 | +#     saved_exported_program = torch.export.load('exported_program.pt2')
260 | 253 | # |
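| | +# Once loaded, the ``ExportedProgram`` can be run by turning it back into a
| | +# callable module (a minimal sketch; reuses ``sample_input`` from above):
| | +#
| | +# .. code-block:: python
| | +#
| | +#     # ExportedProgram.module() returns a runnable torch.nn.Module
| | +#     output = saved_exported_program.module()(sample_input)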
261 | | -# For more information on TorchScript, feel free to visit the dedicated |
262 | | -# `tutorials <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`__. |
263 | | -# You will get familiar with the tracing conversion and learn how to |
264 | | -# run a TorchScript module in a `C++ environment <https://pytorch.org/tutorials/advanced/cpp_export.html>`__. |
265 | | - |
266 | | - |
267 | 254 |
268 | 255 | ###################################################################### |
269 | 256 | # Saving & Loading a General Checkpoint for Inference and/or Resuming Training |
|