File tree Expand file tree Collapse file tree 2 files changed +8
-7
lines changed
python/tvm/auto_scheduler Expand file tree Collapse file tree 2 files changed +8
-7
lines changed Original file line number Diff line number Diff line change 30
30
from .compute_dag import ComputeDAG , LayoutRewriteOption
31
31
from .cost_model import XGBModel
32
32
from .search_policy import SketchPolicy
33
- from .workload_registry import register_workload_tensors
33
+ from .workload_registry import WORKLOAD_FUNC_REGISTRY , register_workload_tensors
34
34
from . import _ffi_api
35
35
36
36
@@ -335,11 +335,12 @@ def __setstate__(self, state):
335
335
except Exception : # pylint: disable=broad-except
336
336
raise RuntimeError ("Invalid workload key %s" % state ["workload_key" ])
337
337
338
- # The workload from a compute DAG does not have arguments and is not registered
339
- # by default so we register it here. If the workload has already been registered,
340
- # the later registration overrides the prvious one.
341
- if len (workload ) == 1 :
342
- register_workload_tensors (workload [0 ], state ["compute_dag" ].tensors )
338
+ # workload[0] is either the compute function name or the ComputeDAG hash.
339
+ # The compute functions are already registered when importing TVM, so here
340
+ # we only register the ComputeDAG workloads. If the same workload has
341
+ # already been registered, the later registration overrides the prvious one.
342
+ if workload [0 ] not in WORKLOAD_FUNC_REGISTRY :
343
+ register_workload_tensors (state ["workload_key" ], state ["compute_dag" ].tensors )
343
344
344
345
self .__init_handle_by_constructor__ (
345
346
_ffi_api .SearchTask ,
Original file line number Diff line number Diff line change @@ -121,7 +121,7 @@ def test_stage_order():
121
121
)
122
122
123
123
task2 = pickle .loads (pickle .dumps (task ))
124
- assert "test-key" in auto_scheduler .workload_registry .WORKLOAD_FUNC_REGISTRY
124
+ assert '[ "test-key"]' in auto_scheduler .workload_registry .WORKLOAD_FUNC_REGISTRY
125
125
assert str (task .compute_dag .get_init_state ()) == str (task2 .compute_dag .get_init_state ())
126
126
assert len (task .compute_dag .get_init_state ().stage_ops ) == len (
127
127
task2 .compute_dag .get_init_state ().stage_ops
You can’t perform that action at this time.
0 commit comments