Skip to content

Commit 30c110c

Browse files
authored
[Bugfix][AutoScheduler] Fail to register ComputeDAG when deserializing tasks (apache#7395)
* [Bugfix][AutoScheduler] Fail to register ComputeDAG when deserialize tasks * fix test * trigger ci
1 parent 1de98be commit 30c110c

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

python/tvm/auto_scheduler/search_task.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from .compute_dag import ComputeDAG, LayoutRewriteOption
3131
from .cost_model import XGBModel
3232
from .search_policy import SketchPolicy
33-
from .workload_registry import register_workload_tensors
33+
from .workload_registry import WORKLOAD_FUNC_REGISTRY, register_workload_tensors
3434
from . import _ffi_api
3535

3636

@@ -335,11 +335,12 @@ def __setstate__(self, state):
335335
except Exception: # pylint: disable=broad-except
336336
raise RuntimeError("Invalid workload key %s" % state["workload_key"])
337337

338-
# The workload from a compute DAG does not have arguments and is not registered
339-
# by default so we register it here. If the workload has already been registered,
340-
# the later registration overrides the prvious one.
341-
if len(workload) == 1:
342-
register_workload_tensors(workload[0], state["compute_dag"].tensors)
338+
# workload[0] is either the compute function name or the ComputeDAG hash.
339+
# The compute functions are already registered when importing TVM, so here
340+
# we only register the ComputeDAG workloads. If the same workload has
341+
# already been registered, the later registration overrides the prvious one.
342+
if workload[0] not in WORKLOAD_FUNC_REGISTRY:
343+
register_workload_tensors(state["workload_key"], state["compute_dag"].tensors)
343344

344345
self.__init_handle_by_constructor__(
345346
_ffi_api.SearchTask,

tests/python/unittest/test_auto_scheduler_compute_dag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test_stage_order():
121121
)
122122

123123
task2 = pickle.loads(pickle.dumps(task))
124-
assert "test-key" in auto_scheduler.workload_registry.WORKLOAD_FUNC_REGISTRY
124+
assert '["test-key"]' in auto_scheduler.workload_registry.WORKLOAD_FUNC_REGISTRY
125125
assert str(task.compute_dag.get_init_state()) == str(task2.compute_dag.get_init_state())
126126
assert len(task.compute_dag.get_init_state().stage_ops) == len(
127127
task2.compute_dag.get_init_state().stage_ops

0 commit comments

Comments
 (0)