Skip to content

Commit 2cb26f1

Browse files
author
Priya Goyal
committed
Tutorial for Super-resolution using caffe2
1 parent 280662e commit 2cb26f1

File tree

3 files changed

+210
-10
lines changed

3 files changed

+210
-10
lines changed

_static/img/SRResNet.png

719 KB
Loading
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
"""
2+
Transfering a model from PyTorch to Caffe2 and Mobile using ONNX
3+
================================================================
4+
5+
"""
6+
7+
8+
######################################################################
9+
# In this tutorial, we describe how to use ONNX to convert a model defined
10+
# in PyTorch into the ONNX format and then load it into Caffe2. Once in
11+
# Caffe2, we can run the model to double-check it was exported correctly,
12+
# and we then show how to use Caffe2 features such as mobile exporter for
13+
# executing the model on mobile devices.
14+
#
15+
16+
# Some standard imports
17+
import io
18+
import numpy as np
19+
20+
from torch import nn
21+
from torch.autograd import Variable
22+
import torch.utils.model_zoo as model_zoo
23+
import torch.onnx
24+
25+
import onnx
26+
import onnx.backend
27+
import onnx.backend.caffe2
28+
29+
30+
######################################################################
31+
# For this tutorial, we will transfer a super-resolution model as an
32+
# example. First, let's create a SuperResolution model in PyTorch. `This
33+
# model <https://github.com/pytorch/examples/blob/master/super_resolution/model.py>`__
34+
# comes directly from PyTorch's examples without modification:
35+
#
36+
37+
# Super Resolution model definition in PyTorch
38+
import torch.nn as nn
39+
import torch.nn.init as init
40+
41+
42+
class SuperResolutionNet(nn.Module):
43+
def __init__(self, upscale_factor, inplace=False):
44+
super(SuperResolutionNet, self).__init__()
45+
46+
self.relu = nn.ReLU(inplace=inplace)
47+
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
48+
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
49+
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
50+
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
51+
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
52+
53+
self._initialize_weights()
54+
55+
def forward(self, x):
56+
x = self.relu(self.conv1(x))
57+
x = self.relu(self.conv2(x))
58+
x = self.relu(self.conv3(x))
59+
x = self.pixel_shuffle(self.conv4(x))
60+
return x
61+
62+
def _initialize_weights(self):
63+
init.orthogonal(self.conv1.weight, init.calculate_gain('relu'))
64+
init.orthogonal(self.conv2.weight, init.calculate_gain('relu'))
65+
init.orthogonal(self.conv3.weight, init.calculate_gain('relu'))
66+
init.orthogonal(self.conv4.weight)
67+
68+
# Create the super-resolution model by using the above model definition.
69+
torch_model = SuperResolutionNet(upscale_factor=3)
70+
71+
72+
######################################################################
73+
# Ordinarily, you would now train this model; however, for this tutorial,
74+
# we will instead download some pre-trained weights.
75+
#
76+
77+
# Load pretrained model weights
78+
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
79+
batch_size = 1 # just a random number
80+
81+
# Initialize model with the pretrained weights
82+
torch_model.load_state_dict(model_zoo.load_url(model_url))
83+
84+
# set the train mode to false since we will only run the forward pass.
85+
torch_model.train(False)
86+
87+
88+
######################################################################
89+
# Exporting a model in PyTorch works via tracing. To export a model, you
90+
# call the ``torch.onnx._export()`` function. This will execute the model,
91+
# recording a trace of what operators are used to compute the outputs.
92+
# Because ``export`` runs the model, we need provide an input tensor
93+
# ``x``. The values in this tensor are not important; it can be an image
94+
# or a random tensor as long as it is the right size.
95+
#
96+
97+
# Input to the model
98+
x = Variable(torch.randn(batch_size, 1, 224, 224), requires_grad=True)
99+
100+
# Export the model
101+
torch_out = torch.onnx._export(torch_model, # model being run
102+
x, # model input (or a tuple for multiple inputs)
103+
"super_resolution.onnx", # where to save the model (can be a file or file-like object)
104+
export_params=True) # store the trained parameter weights inside the model file
105+
106+
107+
######################################################################
108+
# ``torch_out`` is the output after executing the model. Normally you can
109+
# ignore this output, but here we will use it to verify that the model we
110+
# exported computes the same values when run in Caffe2.
111+
#
112+
# Now let's take the ONNX representation and use it in Caffe2. This part
113+
# can normally be done in a separate process or on another machine, but we
114+
# will continue in the same process so that we can verify that Caffe2 and
115+
# PyTorch are computing the same value for the network:
116+
#
117+
118+
# Load the ONNX GraphProto object. Graph is a standard Python protobuf object
119+
graph = onnx.load("super_resolution.onnx")
120+
121+
# prepare the caffe2 backend for executing the model this converts the ONNX graph into a
122+
# Caffe2 NetDef that can execute it. Other ONNX backends, like one for CNTK will be
123+
# availiable soon.
124+
prepared_backend = onnx.backend.caffe2.prepare(graph)
125+
126+
# run the model in Caffe2
127+
128+
# Construct a map from input names to Tensor data.
129+
# The graph itself contains inputs for all weight parameters, followed by the input image.
130+
# Since the weights are already embedded, we just need to pass the input image.
131+
# last input the graph
132+
W = {graph.input[-1]: x.data.numpy()}
133+
134+
# Run the Caffe2 net:
135+
c2_out = prepared_backend.run(W)[0]
136+
137+
# Verify the numerical correctness upto 3 decimal places
138+
np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), c2_out, decimal=3)
139+
140+
141+
######################################################################
142+
# Transfering SRResNet using ONNX
143+
# ===============================
144+
#
145+
146+
147+
######################################################################
148+
# Super-resolution is a way of increasing the resolution of images, videos
149+
# and is widely used in image processing or video editing. For the purpose
150+
# of tutorial, we used a small super-resolution model with a dummy input
151+
# above but using the steps above, we also transferred the SRResNet model
152+
# for super-resolution presented in `this
153+
# paper <https://arxiv.org/pdf/1609.04802.pdf>`__. The model definition
154+
# and a pre-trained model are available
155+
# `here <https://gist.github.com/prigoyal/b245776903efbac00ee89699e001c9bd>`__.
156+
# Below is what SRResNet model input, output looks like. |SRResNet|
157+
#
158+
# .. |SRResNet| image:: /_static/img/SRResNet.png
159+
#
160+
161+
162+
######################################################################
163+
# Running the model on mobile devices
164+
# ===================================
165+
#
166+
167+
168+
######################################################################
169+
# So far we have exported a model from PyTorch and shown how to load it
170+
# and run it in Caffe2. Now that the model is loaded in Caffe2, we can
171+
# convert it into a format suitable for `running on mobile
172+
# devices <https://caffe2.ai/docs/mobile-integration.html>`__.
173+
#
174+
# We will use Caffe2's
175+
# `mobile\_exporter <https://github.com/caffe2/caffe2/blob/master/caffe2/python/predictor/mobile_exporter.py>`__
176+
# to generate the two model protobufs that can run on mobile. The first is
177+
# used to initialize the network with the correct weights, and the second
178+
# actual runs executes the model.
179+
#
180+
181+
# extract the workspace and the graph proto from the internal representation
182+
c2_workspace = prepared_backend.workspace
183+
c2_graph = prepared_backend.predict_net
184+
185+
# Now import the caffe2 mobile exporter
186+
from caffe2.python.predictor import mobile_exporter
187+
188+
# call the Export to get the predict_net, init_net. These nets are needed for running things on mobile
189+
init_net, predict_net = mobile_exporter.Export(c2_workspace, c2_graph, c2_graph.external_input)
190+
191+
192+
######################################################################
193+
# Now, on your ios/Android device, you can use the above protobufs and use
194+
# ``caffe2::Predictor`` (iOS) or Caffe2 instance (Android) for deploying
195+
# them real-time.
196+
# `Here <https://gist.github.com/prigoyal/6bbec5cd121182596848af1c265d7bbe>`__
197+
# is the gist showing steps for running the model on mobile and verifying
198+
# the correctness. Also, for more information, checkout
199+
# `caffe2-android-demo <https://caffe2.ai/docs/AI-Camera-demo-android.html>`__
200+
#

index.rst

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
Welcome to PyTorch Tutorials
22
============================
33

4-
To get started with learning PyTorch, start with our Beginner Tutorials.
5-
The :doc:`60-minute blitz </beginner/deep_learning_60min_blitz>` is the most common
4+
To get started with learning PyTorch, start with our Beginner Tutorials.
5+
The :doc:`60-minute blitz </beginner/deep_learning_60min_blitz>` is the most common
66
starting point, and gives you a quick introduction to PyTorch.
7-
If you like learning by examples, you will like the tutorial
7+
If you like learning by examples, you will like the tutorial
88
:doc:`/beginner/pytorch_with_examples`
99

10-
If you would like to do the tutorials interactively via IPython / Jupyter,
10+
If you would like to do the tutorials interactively via IPython / Jupyter,
1111
each tutorial has a download link for a Jupyter Notebook and Python source code.
1212

13-
We also provide a lot of high-quality examples covering image classification,
14-
unsupervised learning, reinforcement learning, machine translation and
13+
We also provide a lot of high-quality examples covering image classification,
14+
unsupervised learning, reinforcement learning, machine translation and
1515
many other applications at https://github.com/pytorch/examples/
1616

17-
You can find reference documentation for PyTorch's API and layers at
17+
You can find reference documentation for PyTorch's API and layers at
1818
http://docs.pytorch.org or via inline help.
1919
If you would like the tutorials section improved, please open a github issue
2020
here with your feedback: https://github.com/pytorch/tutorials
2121

2222
Beginner Tutorials
2323
------------------
24-
24+
2525
.. customgalleryitem::
2626
:figure: /_static/img/thumbnails/pytorch-logo-flat.png
2727
:tooltip: Understand PyTorch’s Tensor library and neural networks at a high level.
@@ -100,7 +100,7 @@ Reinforcement Learning
100100
:includehidden:
101101
:hidden:
102102
:caption: Intermediate Tutorials
103-
103+
104104
intermediate/char_rnn_classification_tutorial
105105
intermediate/char_rnn_generation_tutorial
106106
intermediate/seq2seq_translation_tutorial
@@ -132,5 +132,5 @@ Advanced Tutorials
132132

133133
advanced/neural_style_tutorial
134134
advanced/numpy_extensions_tutorial
135+
advanced/super_resolution_with_caffe2
135136
advanced/c_extension
136-

0 commit comments

Comments
 (0)