Skip to content

Commit 4ebd5d2

Browse files
AndrewZhaoLuoylc
authored and committed
[Autoscheduler] Configurable workload keys (apache#8862)
* change workload keys * remove binary string comparison * append the tuple not every integer * clean up * lint * dump workload keys to dags * fix things * change some strings * misc fixes, add tests * jostle ci
1 parent 1c85a08 commit 4ebd5d2

File tree

3 files changed

+63
-7
lines changed

3 files changed

+63
-7
lines changed

python/tvm/auto_scheduler/compute_dag.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,20 +222,27 @@ def rewrite_layout_from_state(self, state):
222222

223223
def workload_key(self):
    """Return the workload key of this compute DAG.
    The workload key is a JSON string from a tuple of (hash of DAG, tensor shapes...)

    Returns
    -------
    key: str
        The workload key of this compute DAG
    """
    str_dag = _ffi_api.ComputeDAGPrintDAG(self, True)
    # Users may override hashing of the printed DAG by registering a global
    # function under "auto_scheduler.compute_dag.hash_func"; otherwise fall
    # back to the built-in md5 fingerprint below.
    hash_func = tvm._ffi.get_global_func(
        "auto_scheduler.compute_dag.hash_func", allow_missing=True
    )

    if hash_func is None:
        # md5 is used as a cheap, stable content fingerprint here — not for
        # any security purpose.
        str_dag = str_dag.encode("utf-8")
        hash_key = hashlib.md5(str_dag).hexdigest()
    else:
        hash_key = hash_func(str_dag)

    # Append one shape tuple per I/O tensor (nested, not flattened), so
    # shapes of different ranks cannot collide in the resulting key.
    io_shapes = [get_const_tuple(tensor.shape) for tensor in self.tensors]
    return json.dumps([hash_key] + io_shapes)
240247

241248
def __str__(self):

python/tvm/auto_scheduler/relay_integration.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
2. Provide auto-scheduling for all TOPI compute functions
2323
"""
2424

25+
import json
2526
import logging
2627
import threading
2728
from copy import deepcopy
@@ -30,11 +31,10 @@
3031
from tvm import autotvm, transform
3132
from tvm.ir.transform import PassContext
3233
from tvm.runtime import convert_to_object
33-
34+
from tvm.target import Target
3435
from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
3536
from tvm.tir import Reduce
3637
from tvm.tir import expr as _expr
37-
from tvm.target import Target
3838

3939
from . import _ffi_api
4040
from .compute_dag import ComputeDAG, LayoutRewriteOption
@@ -97,6 +97,7 @@ def extract_tasks(
9797
target_host=None,
9898
hardware_params=None,
9999
include_simple_tasks=False,
100+
dump_workload_to_dag_log=None,
100101
opt_level=3,
101102
):
102103
"""Extract tuning tasks from a relay program.
@@ -115,6 +116,8 @@ def extract_tasks(
115116
Hardware parameters used for the search tasks
116117
include_simple_tasks: bool
117118
Whether to extract simple tasks that do not include complicated ops.
119+
dump_workload_to_dag_log: Optional[str]
120+
A file to dump an association between the workload keys and the actual DAG
118121
opt_level : Optional[int]
119122
The optimization level of the task extractions.
120123
@@ -170,6 +173,10 @@ def extract_tasks(
170173
)
171174
weights.append(weight)
172175

176+
if dump_workload_to_dag_log is not None:
177+
with open(dump_workload_to_dag_log, "w") as f:
178+
json.dump({task.workload_key: str(task.compute_dag) for task in tasks}, f)
179+
173180
return tasks, weights
174181

175182

tests/python/relay/test_auto_scheduler_task_extraction.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Test task extraction for auto-scheduler"""
18-
import pytest
18+
import json
19+
import tempfile
1920

21+
import pytest
2022
import tvm.relay.testing
2123
import tvm.testing
24+
from tvm import _ffi as _ffi_api
2225
from tvm import auto_scheduler, relay
2326

2427

@@ -248,5 +251,44 @@ def verify_task_extraction(func_name, expected_task, include_simple_tasks=False)
248251
verify_task_extraction(*params)
249252

250253

254+
def test_dump_workload_to_dag_extract_tasks():
    """Tasks extracted with dump_workload_to_dag_log should be mirrored in the dump file."""
    mod, _ = get_network("mobilenet", layout="NHWC")
    with tempfile.NamedTemporaryFile() as log_file:
        extracted, _ = auto_scheduler.extract_tasks(
            mod["main"], None, "llvm", include_simple_tasks=True, dump_workload_to_dag_log=log_file.name
        )
        # The dump was written via the file's name; the handle is still at
        # offset 0, so loading through it reads the whole dump.
        dumped = json.load(log_file)
        reference = {}
        for task in extracted:
            reference[task.workload_key] = str(task.compute_dag)
        assert reference == dumped
263+
264+
265+
def test_custom_hash_func_extract_tasks():
    """A user-registered hash function should replace the default DAG hashing.

    NOTE(review): register_func installs a process-global FFI function; it is
    not unregistered afterwards, so later tests in this process will also use
    this counting hash — confirm that is acceptable for the suite.
    """

    @_ffi_api.register_func("auto_scheduler.compute_dag.hash_func")
    def counting_unique_hash(str_dag):
        # Return a fresh integer per call so every DAG gets a unique "hash".
        ret = counting_unique_hash.i
        counting_unique_hash.i += 1
        return ret

    # Counter state lives on the function object itself.
    counting_unique_hash.i = 0

    mod, _ = get_network("mobilenet", layout="NHWC")
    tasks, _ = auto_scheduler.extract_tasks(mod["main"], None, "llvm", include_simple_tasks=True)

    hash_values = []
    for task in tasks:
        # task.workload_key should look like
        # [43, [3, 3, 1024, 1], [1024], [3, 3, 1024, 1]] where the first int is the result of the hash
        # Extract the hash and keep track of every hash
        hash_value = int(task.workload_key[1:].split(",")[0])
        hash_values.append(hash_value)

    # All values are unique, and we know the min and max
    # This is a sufficient condition to know that hashes in hash_values are an increasing list
    # of hashes up to counting_unique_hash.i - 1
    assert len(hash_values) == len(set(hash_values))
    assert min(hash_values) == 0
    assert max(hash_values) == counting_unique_hash.i - 1
291+
292+
251293
if __name__ == "__main__":
252294
pytest.main([__file__])

0 commit comments

Comments
 (0)