Skip to content

Commit

Permalink
[AutoScheduler] Do not return naive schedule in tracing mode (#7226)
Browse files Browse the repository at this point in the history
* [AutoScheduler] Do not return naive schedule in tracing mode

* lint

* fix
  • Loading branch information
comaniac authored Jan 8, 2021
1 parent 29da763 commit 54c995d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
13 changes: 6 additions & 7 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import threading

import tvm
from tvm import autotvm, te, transform
from tvm import autotvm, transform
from tvm.ir.transform import PassContext
from tvm.runtime import convert_to_object
from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
Expand Down Expand Up @@ -267,7 +267,7 @@ def auto_schedule_topi(outs):
-------
sch: Optional[te.Schedule]
A tuned schedule or none (if not tuned) in the final build mode;
An initial schdule in the tracing mode.
None in the tracing mode so that the fallback topi schedule will be used.
"""
# pylint: disable=import-outside-toplevel

Expand All @@ -282,7 +282,6 @@ def auto_schedule_topi(outs):
return None

key = register_workload_tensors(dag.hash_key(), io_tensors)

target = tvm.target.Target.current()

env = TracingEnvironment.current
Expand All @@ -293,11 +292,12 @@ def auto_schedule_topi(outs):
return None

schedule, _ = dag.apply_steps_from_state(state)
elif env.tracing_mode in [TracingMode.EXTRACT_TASK, TracingMode.EXTRACT_COMPLEX_TASK_ONLY]:
return schedule

if env.tracing_mode in [TracingMode.EXTRACT_TASK, TracingMode.EXTRACT_COMPLEX_TASK_ONLY]:
# in the task extraction mode
if has_complex_op or env.tracing_mode == TracingMode.EXTRACT_TASK:
env.add_workload_key(key)
schedule = te.create_schedule([x.op for x in outs])
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
# in prepare_layout_rewrite mode
if (
Expand All @@ -315,11 +315,10 @@ def auto_schedule_topi(outs):
new_key = json.dumps((new_dag.hash_key(),))
if new_key != key:
dispatch_ctx.update(target, new_key, state)
return te.create_schedule([x.op for x in outs])
else:
raise ValueError("Invalid tracing mode: " + env.tracing_mode)

return schedule
return None


def tensor_no_check_call(self, *indices):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def naive_schedule(_, outs, target):
if "gpu" in target.keys:
# For GPU, we at least need thread binding to make a valid schedule.
# So the naive schedule cannot be compiled.
raise RuntimeError(
logger.debug(
"Cannot compile for GPU targets if no tuned schedule is found. "
"Please see the warning messages above for more information about the failed workloads."
)
Expand Down

0 comments on commit 54c995d

Please sign in to comment.