diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 93467e27d0e7..3427709d819a 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -61,7 +61,7 @@ def __init__(self, compute_or_sche): if isinstance(compute_or_sche, str): compute = workload_key_to_tensors(compute_or_sche) sche = None - elif isinstance(compute_or_sche, list): + elif isinstance(compute_or_sche, (list, tvm.ir.container.Array)): for item in compute_or_sche: if not isinstance(item, tvm.te.Tensor): raise ValueError( diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 283d8bf7db45..6864bcce66e3 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -22,6 +22,7 @@ 2. Provide auto-scheduling for all TOPI compute functions """ +import logging import threading import tvm @@ -32,6 +33,8 @@ from .search_task import SearchTask from .workload_registry import register_workload_tensors +logger = logging.getLogger("auto_scheduler") + def call_all_topi_funcs(mod, params, target): """Call all TOPI compute to extract auto_scheduler tasks in a Relay program""" @@ -218,16 +221,19 @@ def auto_schedule_topi(outs, has_complex_op): from tvm import relay io_tensors, has_layout_free = traverse_to_get_io_tensors(outs) - key = register_workload_tensors(io_tensors) - if key is None: # skip this compute if failed to register the workload + try: + dag = ComputeDAG(io_tensors) + except tvm.error.TVMError as err: + logger.info("Failed to create a ComputeDAG for auto_scheduler: %s", str(err)) return None + key = register_workload_tensors(dag.hash_key(), io_tensors) + # only enable layout rewrite for cpu backend enable_layout_rewrite = "cpu" in tvm.target.Target.current().keys env = TracingEnvironment.current if env is None: # in the final build mode - dag = ComputeDAG(io_tensors) state = 
DispatchContext.current.query(tvm.target.Target.current(), key, has_complex_op, dag) if state is None: return None diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 7c5021b3f9b7..f2dadccbf891 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -17,10 +17,13 @@ """ The definiton of SearchTask """ +import json + import tvm._ffi from tvm.runtime import Object from . import _ffi_api +from .workload_registry import register_workload_tensors @tvm._ffi.register_object("auto_scheduler.SearchTask") @@ -63,6 +66,19 @@ def __getstate__(self): def __setstate__(self, state): self.dag = state["dag"] self.workload_key = state["workload_key"] + + # Register the workload if needed + try: + workload = json.loads(self.workload_key) + except Exception: # pylint: disable=broad-except + raise RuntimeError("Invalid workload key %s" % self.workload_key) + + # The workload from a compute DAG does not have arguments and is not registered + # by default so we register it here. If the workload has already been registered, + # the later registration overrides the previous one. + if len(workload) == 1: + register_workload_tensors(workload[0], self.dag.tensors) + self.target = state["target"] self.target_host = state["target_host"] self.hardware_params = state["hardware_params"] diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 6a4809b1796c..9a7c15c877aa 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -64,7 +64,7 @@ def register_workload(func_name, f=None, override=False): f : Optional[Function] The generation function to be registered. override : boolean = False - Whether override existing entry. + Whether to override existing entry. 
Examples -------- @@ -98,30 +98,26 @@ def register(myf): return register -def register_workload_tensors(tensors): - """Register a workload by provding input/output tensors +def register_workload_tensors(func_name, tensors, override=True): + """Register a workload by providing input/output tensors. Since this function is used + when extracting/deserializing tasks, it expects duplicated registrations by default. Parameters ---------- + func_name: str + The function name or the hash key of the compute DAG. tensors: List[Tensor] The input/output tensors of a compute DAG + override : boolean = True + Whether to override existing entry. Returns ------- - key: Optional[str] - The workload key, or None if failed to create a compute DAG. + key: str + The serialized JSON string as the workload key. - # pylint: disable=import-outside-toplevel - from .compute_dag import ComputeDAG - - try: - key = ComputeDAG(tensors).hash_key() - except tvm.error.TVMError as err: - logger.info("Failed to create a ComputeDAG for auto_scheduler: %s", str(err)) - return None - - WORKLOAD_FUNC_REGISTRY[key] = tensors - return json.dumps((key,)) + register_workload(func_name, override=override)(tensors) + return json.dumps((func_name,)) def make_workload_key(func, args): diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py index e7774753796c..caf3c9d888b6 100644 --- a/tests/python/unittest/test_auto_scheduler_compute_dag.py +++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py @@ -16,6 +16,7 @@ # under the License. """Test ComputeDAG (replay, infer bound)""" +import json import pickle import tvm @@ -120,11 +121,13 @@ def test_stage_order(): # Serialize and deserialize the search task. 
task = auto_scheduler.SearchTask( dag, - "test1", + json.dumps(("test-key",)), tvm.target.Target("llvm"), hardware_params=auto_scheduler.HardwareParams(100000, 16, 64), ) + task2 = pickle.loads(pickle.dumps(task)) + assert "test-key" in auto_scheduler.workload_registry.WORKLOAD_FUNC_REGISTRY assert str(task.dag.get_init_state()) == str(task2.dag.get_init_state()) assert len(task.dag.get_init_state().stage_ops) == len(task2.dag.get_init_state().stage_ops) assert task.workload_key == task2.workload_key