Skip to content

Commit 055a90e

Browse files
houseroadchsasank
authored andcommitted
Update the Super Resolution example, so it is compatible with latest ONNX (pytorch#158)
1 parent 62795e5 commit 055a90e

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

advanced_source/super_resolution_with_caffe2.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,28 +131,29 @@ def _initialize_weights(self):
131131
import onnx
132132
import onnx_caffe2.backend
133133

134-
# Load the ONNX GraphProto object. Graph is a standard Python protobuf object
135-
graph = onnx.load("super_resolution.onnx")
134+
# Load the ONNX ModelProto object. model is a standard Python protobuf object
135+
model = onnx.load("super_resolution.onnx")
136136

137-
# prepare the caffe2 backend for executing the model this converts the ONNX graph into a
137+
# prepare the caffe2 backend for executing the model this converts the ONNX model into a
138138
# Caffe2 NetDef that can execute it. Other ONNX backends, like one for CNTK will be
139139
# availiable soon.
140-
prepared_backend = onnx_caffe2.backend.prepare(graph)
140+
prepared_backend = onnx_caffe2.backend.prepare(model)
141141

142142
# run the model in Caffe2
143143

144144
# Construct a map from input names to Tensor data.
145-
# The graph itself contains inputs for all weight parameters, followed by the input image.
145+
# The graph of the model itself contains inputs for all weight parameters, after the input image.
146146
# Since the weights are already embedded, we just need to pass the input image.
147-
# last input the grap
148-
W = {graph.input[-1]: x.data.numpy()}
147+
# Set the first input.
148+
W = {model.graph.input[0].name: x.data.numpy()}
149149

150150
# Run the Caffe2 net:
151151
c2_out = prepared_backend.run(W)[0]
152152

153153
# Verify the numerical correctness upto 3 decimal places
154154
np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), c2_out, decimal=3)
155155

156+
print("Exported model has been executed on Caffe2 backend, and the result looks good!")
156157

157158
######################################################################
158159
# We should see that the output of PyTorch and Caffe2 runs match
@@ -202,15 +203,15 @@ def _initialize_weights(self):
202203
# super-resolution model for the rest of this tutorial.
203204
#
204205

205-
# extract the workspace and the graph proto from the internal representation
206+
# extract the workspace and the model proto from the internal representation
206207
c2_workspace = prepared_backend.workspace
207-
c2_graph = prepared_backend.predict_net
208+
c2_model = prepared_backend.predict_net
208209

209210
# Now import the caffe2 mobile exporter
210211
from caffe2.python.predictor import mobile_exporter
211212

212213
# call the Export to get the predict_net, init_net. These nets are needed for running things on mobile
213-
init_net, predict_net = mobile_exporter.Export(c2_workspace, c2_graph, c2_graph.external_input)
214+
init_net, predict_net = mobile_exporter.Export(c2_workspace, c2_model, c2_model.external_input)
214215

215216
# Let's also save the init_net and predict_net to a file that we will later use for running them on mobile
216217
with open('init_net.pb', "wb") as fopen:

0 commit comments

Comments
 (0)