[Metaschedule] Add utility API to ease using manual schedules #10876

Merged · 9 commits · Apr 5, 2022
refactored test_meta_schedule_tune_relay.py
masahi committed Apr 2, 2022
commit 65733d1a791b874b43492c4675f243be50771078
75 changes: 35 additions & 40 deletions tests/python/unittest/test_meta_schedule_tune_relay.py
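
The refactored test boils down to the following pattern (a minimal sketch, not the test's actual code): a `schedule_fn(task, sch)` callback returns `True` for the tasks it schedules by hand, and the new `apply_manual_schedules` helper turns those schedules into a database that `ApplyHistoryBest` can consume. Here `relay_mod`, `target`, and `params` are assumed to be set up as in `manual_tir_common()`, and the `PassContext` options are assumptions based on similar meta-schedule tests.

```python
# Minimal usage sketch of the new helper, as exercised by the diff below.
# relay_mod, target and params are assumed from the surrounding test; the
# callback body is a placeholder rather than the test's VNNI dense schedule.
import tvm
from tvm import relay
from tvm.meta_schedule.integration import ApplyHistoryBest
from tvm.meta_schedule.tune import apply_manual_schedules


def schedule_fn(task, sch):
    # Claim only the tasks we intend to schedule by hand; returning False
    # leaves the task to the default lowering.
    if "dense" not in task.task_name:
        return False
    # ... apply the manual TIR schedule to `sch` here ...
    return True


database = apply_manual_schedules(relay_mod, target, params, schedule_fn)

# The returned database plugs into the usual ApplyHistoryBest build flow.
with ApplyHistoryBest(database):
    with tvm.transform.PassContext(
        opt_level=3,
        config={"relay.backend.use_meta_schedule": True},  # assumed config flag
    ):
        lib = relay.build(relay_mod, target=target, params=params)
```
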
Original file line number Diff line number Diff line change
@@ -32,10 +32,10 @@
 from tvm.meta_schedule.integration import ApplyHistoryBest
 from tvm.meta_schedule.testing.relay_workload import get_network
 from tvm.meta_schedule.tune import (
-    Parse,
     extract_task_from_relay,
     tune_extracted_tasks,
     tune_relay,
+    apply_manual_schedules,
 )
 from tvm.meta_schedule.utils import derived_object
 from tvm.script import tir as T
@@ -479,51 +479,45 @@ def manual_tir_common(do_tune=False):
 
     params = {"weight": weight_np, "bias": bias_np}
 
-    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+    if do_tune:
+        extracted_tasks = extract_task_from_relay(relay_mod, target, params)
 
-    # Filter out tasks that we don't intend to schedule / tune with TIR.
-    tune_tasks = list(
-        filter(
-            lambda task: "dense" in task.task_name,
-            extracted_tasks,
-        )
-    )
+        # Filter out tasks that we don't intend to schedule / tune with TIR.
+        tune_tasks = list(
+            filter(
+                lambda task: "dense" in task.task_name,
+                extracted_tasks,
+            )
+        )
+        config = ReplayTraceConfig(
+            num_trials_per_iter=64,
+            num_trials_total=64,
+        )
 
-    with tempfile.TemporaryDirectory() as work_dir:
-        if do_tune:
-            config = ReplayTraceConfig(
-                num_trials_per_iter=64,
-                num_trials_total=64,
-            )
+        with tempfile.TemporaryDirectory() as work_dir:
             # postprocs=lambda: [] is important to prevent default post processors from
             # tampering with the manual schedule.
             database = tune_extracted_tasks(
                 tune_tasks, target, config, work_dir=work_dir, postprocs=lambda: []
             )
-        else:
-            database = JSONDatabase(
-                path_workload=osp.join(work_dir, "database_workload.json"),
-                path_tuning_record=osp.join(work_dir, "database_tuning_record.json"),
-            )
+    else:
 
-            for task in tune_tasks:
-                mod = Parse._mod(task.dispatched[0])
-                workload = database.commit_workload(mod)
+        def schedule_fn(task, sch):
+            if "dense" not in task.task_name:
+                return False
 
-                sch = tvm.tir.Schedule(mod)
-                block = sch.get_block("compute")
+            block = sch.get_block("compute")
 
-                # Looks up schedule_rule annotation. See the comment in test_tune_relay_manual_tir_vnni().
-                schedule_rule = sch.get(block).annotations["schedule_rule"]
+            # Looks up schedule_rule annotation. See the comment in test_tune_relay_manual_tir_vnni().
+            schedule_rule = sch.get(block).annotations["schedule_rule"]
+            assert "dense_vnni" in schedule_rule
 
-                if "dense_vnni" in schedule_rule:
-                    schedule_dense(block, M, False, sch)
+            schedule_dense(block, M, False, sch)
 
-                # [0.0] is for dummy measurement. There is only one tuning record so ApplyHistoryBest
-                # will always have only one option.
-                tune_rec = TuningRecord(sch.trace, [0.0], workload, tvm.target.Target(target), [])
+            return True
 
-                database.commit_tuning_record(tune_rec)
+        database = apply_manual_schedules(relay_mod, target, params, schedule_fn)
 
     with ApplyHistoryBest(database):
         with tvm.transform.PassContext(
@@ -559,6 +553,7 @@ def test_tune_relay_manual_tir_vnni():
     tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_vnni)
 
     manual_tir_common(do_tune=False)
+    return
 
     """
     We can inject and apply a custom TIR scheduling to a TE compute of interest, using
@@ -593,12 +588,12 @@ def schedule_rule_dense_vnni(sch: Schedule, dense_block: BlockRV):
 
 
 if __name__ == """__main__""":
-    test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "llvm --num-cores=16")
-    test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "nvidia/geforce-rtx-3070")
-    test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "llvm --num-cores=16")
-    test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "nvidia/geforce-rtx-3070")
-    test_meta_schedule_tune_relay("bert_base", [1, 64], "llvm --num-cores=16")
-    test_meta_schedule_tune_relay("bert_base", [1, 64], "nvidia/geforce-rtx-3070")
-    test_meta_schedule_te2primfunc_argument_order()
-    test_meta_schedule_relay_lowering()
+    # test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "llvm --num-cores=16")
+    # test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "nvidia/geforce-rtx-3070")
+    # test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "llvm --num-cores=16")
+    # test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "nvidia/geforce-rtx-3070")
+    # test_meta_schedule_tune_relay("bert_base", [1, 64], "llvm --num-cores=16")
+    # test_meta_schedule_tune_relay("bert_base", [1, 64], "nvidia/geforce-rtx-3070")
+    # test_meta_schedule_te2primfunc_argument_order()
+    # test_meta_schedule_relay_lowering()
    test_tune_relay_manual_tir_vnni()
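
For reference, the bookkeeping removed from the `else:` branch above suggests roughly what a helper like `apply_manual_schedules` has to do internally: extract tasks, hand each dispatched module to the callback as a `tir.Schedule`, and commit the resulting trace to a database with a dummy run time. The sketch below is reconstructed from that removed code, not the PR's actual implementation; the function name and the explicit `work_dir` parameter are made up for illustration.

```python
# Reconstruction of the deleted manual loop as a reusable helper; a guess at
# what apply_manual_schedules does, not the implementation added by this PR.
import os.path as osp

import tvm
from tvm.meta_schedule.database import JSONDatabase, TuningRecord
from tvm.meta_schedule.tune import Parse, extract_task_from_relay


def apply_manual_schedules_sketch(relay_mod, target, params, schedule_fn, work_dir):
    """Record manually scheduled tasks so ApplyHistoryBest can pick them up."""
    database = JSONDatabase(
        path_workload=osp.join(work_dir, "database_workload.json"),
        path_tuning_record=osp.join(work_dir, "database_tuning_record.json"),
    )
    for task in extract_task_from_relay(relay_mod, target, params):
        mod = Parse._mod(task.dispatched[0])
        sch = tvm.tir.Schedule(mod)
        # Let the caller decide whether and how to schedule this task.
        if not schedule_fn(task, sch):
            continue
        workload = database.commit_workload(mod)
        # [0.0] is a dummy measurement: with a single record per workload,
        # ApplyHistoryBest has only one option to choose from.
        tune_rec = TuningRecord(sch.trace, [0.0], workload, tvm.target.Target(target), [])
        database.commit_tuning_record(tune_rec)
    return database
```
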