[AutoScheduler] Register workload when deserializing tasks (#6927)
* [AutoScheduler] Register workload when deserializing tasks

* fix name

* format

* merge

* fix test

* more checks
comaniac authored Nov 20, 2020
1 parent 712663e commit 56e226a
Showing 5 changed files with 42 additions and 21 deletions.
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/compute_dag.py
@@ -61,7 +61,7 @@ def __init__(self, compute_or_sche):
if isinstance(compute_or_sche, str):
compute = workload_key_to_tensors(compute_or_sche)
sche = None
elif isinstance(compute_or_sche, list):
elif isinstance(compute_or_sche, (list, tvm.ir.container.Array)):
for item in compute_or_sche:
if not isinstance(item, tvm.te.Tensor):
raise ValueError(
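The widened isinstance check lets a ComputeDAG be rebuilt from tensors that come back through the FFI as a tvm.ir.container.Array rather than a plain Python list. A minimal sketch of what this permits, assuming a typical te-defined matmul (the shapes and names are illustrative, not part of the commit):

from tvm import te, auto_scheduler

# Illustrative compute definition; any te-defined workload would do.
A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
k = te.reduce_axis((0, 128), name="k")
C = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

dag = auto_scheduler.ComputeDAG([A, B, C])     # plain Python list, accepted before this change
dag2 = auto_scheduler.ComputeDAG(dag.tensors)  # dag.tensors is a tvm.ir.container.Array; now also accepted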
12 changes: 9 additions & 3 deletions python/tvm/auto_scheduler/relay_integration.py
@@ -22,6 +22,7 @@
2. Provide auto-scheduling for all TOPI compute functions
"""

import logging
import threading

import tvm
@@ -32,6 +33,8 @@
from .search_task import SearchTask
from .workload_registry import register_workload_tensors

logger = logging.getLogger("auto_scheduler")


def call_all_topi_funcs(mod, params, target):
"""Call all TOPI compute to extract auto_scheduler tasks in a Relay program"""
@@ -218,16 +221,19 @@ def auto_schedule_topi(outs, has_complex_op):
from tvm import relay

io_tensors, has_layout_free = traverse_to_get_io_tensors(outs)
key = register_workload_tensors(io_tensors)
if key is None: # skip this compute if failed to register the workload
try:
dag = ComputeDAG(io_tensors)
except tvm.error.TVMError as err:
logger.info("Failed to create a ComputeDAG for auto_scheduler: %s", str(err))
return None

key = register_workload_tensors(dag.hash_key(), io_tensors)

# only enable layout rewrite for cpu backend
enable_layout_rewrite = "cpu" in tvm.target.Target.current().keys

env = TracingEnvironment.current
if env is None: # in the final build mode
dag = ComputeDAG(io_tensors)
state = DispatchContext.current.query(tvm.target.Target.current(), key, has_complex_op, dag)
if state is None:
return None
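The key change in auto_schedule_topi is the ordering: the ComputeDAG is built first, failures are logged and skipped, and only then is the workload registered under the DAG's hash. A hedged sketch of that flow as a standalone helper (register_io_tensors is a hypothetical name, not a function added by this commit):

import logging

import tvm
from tvm.auto_scheduler import ComputeDAG
from tvm.auto_scheduler.workload_registry import register_workload_tensors

logger = logging.getLogger("auto_scheduler")


def register_io_tensors(io_tensors):
    """Return a workload key for io_tensors, or None if no ComputeDAG can be built."""
    try:
        dag = ComputeDAG(io_tensors)
    except tvm.error.TVMError as err:
        # A compute that cannot form a DAG is skipped instead of aborting task extraction.
        logger.info("Failed to create a ComputeDAG for auto_scheduler: %s", str(err))
        return None
    # The DAG hash becomes the registry key; the tensors are the registered payload.
    return register_workload_tensors(dag.hash_key(), io_tensors)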
16 changes: 16 additions & 0 deletions python/tvm/auto_scheduler/search_task.py
@@ -17,10 +17,13 @@

""" The definiton of SearchTask """

import json

import tvm._ffi
from tvm.runtime import Object

from . import _ffi_api
from .workload_registry import register_workload_tensors


@tvm._ffi.register_object("auto_scheduler.SearchTask")
@@ -63,6 +66,19 @@ def __getstate__(self):
def __setstate__(self, state):
self.dag = state["dag"]
self.workload_key = state["workload_key"]

# Register the workload if needed
try:
workload = json.loads(self.workload_key)
except Exception: # pylint: disable=broad-except
raise RuntimeError("Invalid workload key %s" % self.workload_key)

# The workload from a compute DAG does not have arguments and is not registered
# by default so we register it here. If the workload has already been registered,
# the later registration overrides the previous one.
if len(workload) == 1:
register_workload_tensors(workload[0], self.dag.tensors)

self.target = state["target"]
self.target_host = state["target_host"]
self.hardware_params = state["hardware_params"]
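The len(workload) == 1 check in __setstate__ distinguishes the two workload-key shapes: a key derived from a ComputeDAG is a JSON-encoded one-element tuple holding only the hash and must be re-registered, while a key built by make_workload_key for a @register_workload function also carries its call arguments and is already in the registry. A small illustration (both key values below are made up):

import json

# Key shaped like ComputeDAG.hash_key(): needs re-registration after unpickling.
dag_key = json.dumps(("f3b1a6bb90d59c91b68989ad3492ff00",))
# Key shaped like make_workload_key(func, args) for a registered function: already known.
func_key = json.dumps(("matmul_add", 128, 128, 128, "float32"))

assert len(json.loads(dag_key)) == 1   # one-element form: __setstate__ registers the tensors
assert len(json.loads(func_key)) > 1   # function form: nothing extra to do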
28 changes: 12 additions & 16 deletions python/tvm/auto_scheduler/workload_registry.py
@@ -64,7 +64,7 @@ def register_workload(func_name, f=None, override=False):
f : Optional[Function]
The generation function to be registered.
override : boolean = False
Whether override existing entry.
Whether to override existing entry.
Examples
--------
@@ -98,30 +98,26 @@ def register(myf):
return register


def register_workload_tensors(tensors):
"""Register a workload by provding input/output tensors
def register_workload_tensors(func_name, tensors, override=True):
"""Register a workload by provding input/output tensors. Since this function is used
when extracting/deserializing tasks, it expects duplicated registrations by default.
Parameters
----------
func_name: str
The function name or the hash key of the compute DAG.
tensors: List[Tensor]
The input/output tensors of a compute DAG
override : boolean = True
Whether to override existing entry.
Returns
-------
key: Optional[str]
The workload key, or None if failed to create a compute DAG.
key: str
The serialized JSON string as the workload key.
"""
# pylint: disable=import-outside-toplevel
from .compute_dag import ComputeDAG

try:
key = ComputeDAG(tensors).hash_key()
except tvm.error.TVMError as err:
logger.info("Failed to create a ComputeDAG for auto_scheduler: %s", str(err))
return None

WORKLOAD_FUNC_REGISTRY[key] = tensors
return json.dumps((key,))
register_workload(func_name, override=override)(tensors)
return json.dumps((func_name,))


def make_workload_key(func, args):
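With the new signature the caller supplies the key (typically ComputeDAG(...).hash_key()) and override defaults to True, so repeated extraction or deserialization of the same workload no longer fails on a duplicate entry. A minimal usage sketch with an illustrative elementwise compute, not taken from the commit:

from tvm import te, auto_scheduler
from tvm.auto_scheduler.workload_registry import register_workload_tensors

# Illustrative tensors; any te-defined compute works.
A = te.placeholder((64, 64), name="A")
B = te.compute((64, 64), lambda i, j: A[i, j] * 2.0, name="B")

dag = auto_scheduler.ComputeDAG([A, B])
key = register_workload_tensors(dag.hash_key(), [A, B])
# key is the JSON-encoded one-element tuple usable as a SearchTask workload key;
# calling this again with the same hash simply overrides the existing registry entry.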
5 changes: 4 additions & 1 deletion tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -16,6 +16,7 @@
# under the License.

"""Test ComputeDAG (replay, infer bound)"""
import json
import pickle

import tvm
@@ -120,11 +121,13 @@ def test_stage_order():
# Serialize and deserialize the search task.
task = auto_scheduler.SearchTask(
dag,
"test1",
json.dumps(("test-key",)),
tvm.target.Target("llvm"),
hardware_params=auto_scheduler.HardwareParams(100000, 16, 64),
)

task2 = pickle.loads(pickle.dumps(task))
assert "test-key" in auto_scheduler.workload_registry.WORKLOAD_FUNC_REGISTRY
assert str(task.dag.get_init_state()) == str(task2.dag.get_init_state())
assert len(task.dag.get_init_state().stage_ops) == len(task2.dag.get_init_state().stage_ops)
assert task.workload_key == task2.workload_key