diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index b39aba227a88..68f53125c7ae 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -283,10 +283,13 @@ def auto_schedule_topi(outs): key = register_workload_tensors(dag.workload_key(), io_tensors) target = tvm.target.Target.current() + dispatch_ctx = DispatchContext.current + state = dispatch_ctx.query(target, key, has_complex_op, dag) + schedule = None + env = TracingEnvironment.current if env is None: # in the final build mode - state = DispatchContext.current.query(target, key, has_complex_op, dag) if state is None: return None @@ -303,8 +306,6 @@ def auto_schedule_topi(outs): LayoutRewriteOption.get_target_default(target, True) != LayoutRewriteOption.NO_REWRITE and has_layout_free ): - dispatch_ctx = DispatchContext.current - state = dispatch_ctx.query(target, key, has_complex_op, dag) if state is None: return None @@ -316,7 +317,7 @@ def auto_schedule_topi(outs): else: raise ValueError("Invalid tracing mode: " + env.tracing_mode) - return None + return schedule def tensor_no_check_call(self, *indices):