Skip to content

Commit e6a1b1b

Browse files
ganlerylc
authored andcommitted
[Tutorial][Executor] Fix the usage of executors in tutorials (apache#8586)
* fix: executor usage for keras tutorial * fix: executor usage for onnx tutorial * [Tutorial][Executor] Fix executors in tutorials
1 parent bcc3469 commit e6a1b1b

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

tutorials/dev/bring_your_own_datatypes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,9 @@ def get_cat_image():
257257
######################################################################
258258
# It's easy to execute MobileNet with native TVM:
259259

260+
ex = tvm.relay.create_executor("graph", mod=module, params=params)
260261
input = get_cat_image()
261-
result = tvm.relay.create_executor("graph", mod=module).evaluate()(input, **params).numpy()
262+
result = ex.evaluate()(input).numpy()
262263
# print first 10 elements
263264
print(result.flatten()[:10])
264265

tutorials/frontend/from_keras.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,14 @@
103103
# due to a latent bug. Note that the pass context only has an effect within
104104
# evaluate() and is not captured by create_executor().
105105
with tvm.transform.PassContext(opt_level=0):
106-
model = relay.build_module.create_executor("graph", mod, dev, target).evaluate()
106+
model = relay.build_module.create_executor("graph", mod, dev, target, params).evaluate()
107107

108108

109109
######################################################################
110110
# Execute on TVM
111111
# ---------------
112112
dtype = "float32"
113-
tvm_out = model(tvm.nd.array(data.astype(dtype)), **params)
113+
tvm_out = model(tvm.nd.array(data.astype(dtype)))
114114
top1_tvm = np.argmax(tvm_out.numpy()[0])
115115

116116
#####################################################################

tutorials/frontend/from_onnx.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,15 @@
9292
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
9393

9494
with tvm.transform.PassContext(opt_level=1):
95-
compiled = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target).evaluate()
95+
executor = relay.build_module.create_executor(
96+
"graph", mod, tvm.cpu(0), target, params
97+
).evaluate()
9698

9799
######################################################################
98100
# Execute on TVM
99101
# ---------------------------------------------
100102
dtype = "float32"
101-
tvm_output = compiled(tvm.nd.array(x.astype(dtype)), **params).numpy()
103+
tvm_output = executor(tvm.nd.array(x.astype(dtype))).numpy()
102104

103105
######################################################################
104106
# Display results

0 commit comments

Comments
 (0)