
[Metaschedule] Add utility API to ease using manual schedules #10876

Merged · 9 commits · Apr 5, 2022
9 changes: 8 additions & 1 deletion python/tvm/meta_schedule/testing/__init__.py
@@ -15,4 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities in meta schedule"""
from .utils import DummyDatabase, DummyBuilder, DummyRunner, DummyRunnerFuture, DummyMutator
from .utils import (
DummyDatabase,
DummyBuilder,
DummyRunner,
DummyRunnerFuture,
DummyMutator,
apply_fixed_schedules,
)
53 changes: 51 additions & 2 deletions python/tvm/meta_schedule/testing/utils.py
@@ -15,11 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilitiy functions in meta schedule"""
from typing import List, Optional
import random
from typing import List, Optional, Callable, Dict, Union

import tvm

from tvm.relay import Function as RelayFunc
from tvm.tir import Schedule
from tvm.target import Target
from tvm.runtime import NDArray
from tvm.meta_schedule import TuneContext # pylint: disable=unused-import
from tvm.meta_schedule.utils import derived_object
from tvm.meta_schedule.mutator.mutator import PyMutator
@@ -32,6 +35,9 @@
PyRunnerFuture,
PyRunner,
)
from tvm.meta_schedule.tune import Parse, extract_task_from_relay
from tvm.meta_schedule.integration import ExtractedTask

from tvm.ir import IRModule
from tvm.tir.schedule import Trace

@@ -110,3 +116,46 @@ def initialize_with_tune_context(self, context: "TuneContext") -> None:

def apply(self, trace: Trace, _) -> Optional[Trace]:
return Trace(trace.insts, {})


def apply_fixed_schedules(
relay_mod: Union[RelayFunc, IRModule],
target: Union[str, Target],
params: Optional[Dict[str, NDArray]],
schedule_fn: Callable[[ExtractedTask, Schedule], bool],
):
"""Apply fixed schedules (manually written, without any tunable knobs) as specified by
schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest.
Parameters
----------
mod : Union[RelayFunc, IRModule]
The Relay module to apply fixed schedules.
target : Union[str, Target]
The target used to extract tasks.
params : Optional[Dict[str, tvm.runtime.NDArray]]
The associated parameters of the module.
    schedule_fn : Callable[[ExtractedTask, Schedule], bool]
        A callable applied to each extracted task together with its default schedule.
        It returns True if the given schedule should be committed to the database,
        and False otherwise.

    Returns
    -------
    database : Database
        The database containing dummy tuning records for the manually scheduled traces.
    """

Review comments on the schedule_fn signature:

Member: What about changing it to Callable[[ExtractedTask], Schedule]? That is, the dispatched IRModule is already inside the task, and the callback can return a schedule if the task matches our rule and None otherwise. I think it might be better to avoid confusion.

Member (Author): That's possible, but it would force users to write

    mod = Parse._mod(task.dispatched[0])
    sch = Schedule(mod)

in every schedule_fn callback. I think this boilerplate is non-trivial (users shouldn't have to care about dispatched or Parse).

Member: That makes sense. I wonder whether using the task name and the schedule would suffice, if we don't want users to care about details inside the extracted task.

Member (Author): There are some cases where having the Relay mod is required. For example, I want to be able to skip tasks based on the output dtype of the compute, which can be retrieved from the Relay mod.

Member: Sounds good. Thanks for the explanation.
target = Target(target) if isinstance(target, str) else target
extracted_tasks = extract_task_from_relay(relay_mod, target, params)

database = DummyDatabase()

for task in extracted_tasks:
mod = Parse._mod(task.dispatched[0])
sch = Schedule(mod)

if schedule_fn(task, sch):
workload = database.commit_workload(mod)
Review comments on database.commit_workload(mod):

Member: I think it should be sch.mod, given it has gone through a schedule function.

Member (Author): I don't think so. The purpose of this workload commit is to match against the unmodified mod during ApplyHistoryBest. So we want to commit the original mod as is.

Member: That is correct, thank you.

tune_rec = TuningRecord(sch.trace, [0.0], workload, target, [])
database.commit_tuning_record(tune_rec)

return database
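
For orientation, here is a minimal end-to-end sketch of how the new utility is meant to be used, along the lines of the tests updated below. Only apply_fixed_schedules and ApplyHistoryBest come from this PR; the toy Relay workload, the "dense" task-name filter, and the PassContext config key are illustrative assumptions (the exact config is truncated in the test diffs).

import numpy as np

import tvm
from tvm import relay
from tvm.meta_schedule.testing import apply_fixed_schedules
from tvm.meta_schedule.integration import ApplyHistoryBest

# A toy Relay workload: a single dense layer (illustrative only).
data = relay.var("data", shape=(128, 128), dtype="float32")
weight = relay.var("weight", shape=(128, 128), dtype="float32")
relay_mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight))
params = {"weight": np.random.randn(128, 128).astype("float32")}
target = "llvm"


def schedule_fn(task, sch):
    # Pick the tasks to schedule by hand; everything else keeps its default lowering.
    if "dense" not in task.task_name:
        return False
    # Apply a manual, knob-free TIR schedule to `sch` here.
    # (Placeholder: this sketch commits the default schedule unchanged.)
    return True


database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)

with ApplyHistoryBest(database):
    # Assumption: this config key is how the tests below enable the
    # meta schedule integration; it is not visible in the truncated diff.
    with tvm.transform.PassContext(
        opt_level=3,
        config={"relay.backend.use_meta_schedule": True},
    ):
        lib = relay.build(relay_mod, target=target, params=params)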
39 changes: 8 additions & 31 deletions tests/python/unittest/test_meta_schedule_multi_anchor.py
@@ -14,16 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import tempfile

import numpy as np

import tvm
import tvm.testing
from tvm import relay
from tvm.meta_schedule.tune import Parse, extract_task_from_relay
from tvm.meta_schedule.database import TuningRecord, JSONDatabase
from tvm.meta_schedule.testing import apply_fixed_schedules
from tvm.meta_schedule.integration import ApplyHistoryBest


@@ -72,39 +68,20 @@ def test_dense_dense():

# print(relay.transform.InferType()(relay_mod))

target = "llvm"

data_np = np.random.randn(*data_shape).astype("float32")
weight1_np = np.random.randn(*weight_shape).astype("float32")
weight2_np = np.random.randn(*weight_shape).astype("float32")

target = "llvm"
params = {"weight1": weight1_np, "weight2": weight2_np}

extracted_tasks = extract_task_from_relay(relay_mod, target, params)

assert len(extracted_tasks) == 1

task = extracted_tasks[0]

mod = Parse._mod(task.dispatched[0])

with tempfile.TemporaryDirectory() as work_dir:
database = JSONDatabase(
path_workload=os.path.join(work_dir, "database_workload.json"),
path_tuning_record=os.path.join(work_dir, "database_tuning_record.json"),
)

workload = database.commit_workload(mod)

sch = tvm.tir.Schedule(mod)

schedule_dense_dense(sch)

# print(sch.mod.script())

tune_rec = TuningRecord(sch.trace, [0.0], workload, tvm.target.Target(target), [])
def schedule_fn(task, sch):
if "nn_dense_nn_dense" in task.task_name:
schedule_dense_dense(sch)
return True
return False

database.commit_tuning_record(tune_rec)
database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)

with ApplyHistoryBest(database):
with tvm.transform.PassContext(
60 changes: 27 additions & 33 deletions tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -31,8 +31,8 @@
from tvm.meta_schedule.database import JSONDatabase, PyDatabase, TuningRecord, Workload
from tvm.meta_schedule.integration import ApplyHistoryBest
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.testing import apply_fixed_schedules
from tvm.meta_schedule.tune import (
Parse,
extract_task_from_relay,
tune_extracted_tasks,
tune_relay,
@@ -480,52 +480,46 @@ def manual_tir_common(do_tune=False):

params = {"weight": weight_np, "bias": bias_np}

extracted_tasks = extract_task_from_relay(relay_mod, target, params)
if do_tune:
extracted_tasks = extract_task_from_relay(relay_mod, target, params)

# Filter out tasks that we don't intend to schedule / tune with TIR.
tune_tasks = list(
filter(
lambda task: "dense" in task.task_name,
extracted_tasks,
# Filter out tasks that we don't intend to schedule / tune with TIR.
tune_tasks = list(
filter(
lambda task: "dense" in task.task_name,
extracted_tasks,
)
)
config = ReplayTraceConfig(
num_trials_per_iter=64,
max_trials_per_task=64,
max_trials_global=20000,
)
)

with tempfile.TemporaryDirectory() as work_dir:
if do_tune:
config = ReplayTraceConfig(
num_trials_per_iter=64,
max_trials_per_task=64,
max_trials_global=20000,
)
with tempfile.TemporaryDirectory() as work_dir:
# postprocs=lambda: [] is important to prevent default post processors from
# tampering with the manual schedule.
database = tune_extracted_tasks(
tune_tasks, target, config, work_dir=work_dir, postprocs=lambda: []
)
else:
database = JSONDatabase(
path_workload=osp.join(work_dir, "database_workload.json"),
path_tuning_record=osp.join(work_dir, "database_tuning_record.json"),
)
else:

def schedule_fn(task, sch):
if "dense" not in task.task_name:
return False

for task in tune_tasks:
mod = Parse._mod(task.dispatched[0])
workload = database.commit_workload(mod)
block = sch.get_block("compute")

sch = tvm.tir.Schedule(mod)
block = sch.get_block("compute")
# Looks up schedule_rule annotation. See the comment in test_tune_relay_manual_tir_vnni().
schedule_rule = sch.get(block).annotations["schedule_rule"]

# Looks up schedule_rule annotation. See the comment in test_tune_relay_manual_tir_vnni().
schedule_rule = sch.get(block).annotations["schedule_rule"]
assert "dense_vnni" in schedule_rule

if "dense_vnni" in schedule_rule:
schedule_dense(block, M, False, sch)
schedule_dense(block, M, False, sch)

# [0.0] is for dummy measurement. There is only one tuning record so ApplyHistoryBest
# will always have only one option.
tune_rec = TuningRecord(sch.trace, [0.0], workload, tvm.target.Target(target), [])
return True

database.commit_tuning_record(tune_rec)
database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)

with ApplyHistoryBest(database):
with tvm.transform.PassContext(
121 changes: 0 additions & 121 deletions tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -91,127 +91,6 @@ def test_tune_matmul_cuda():
print(sch.trace)


@pytest.mark.skip("Integeration test")
Review comments on the removed test:

Member: May I ask why this test is removed?

Member (Author): This test depends on auto-tensorization for tensorcore, which is not in main.

Member: I see. We may upstream it later then.

Member: Oops... thanks for spotting this!

def test_tune_matmul_cuda_tensor_core():
n = 512
mod = create_prim_func(te_workload.matmul_fp16(n, n, n))
target = Target("nvidia/geforce-rtx-3070")
config = ReplayTraceConfig(
num_trials_per_iter=32,
max_trials_per_task=320,
max_trials_global=320,
)

class DefaultTensorCore:
@staticmethod
def _sch_rules():
from tvm.meta_schedule import (
schedule_rule as M, # pylint: disable=import-outside-toplevel
)

return [
M.AutoInline(
into_producer=False,
into_consumer=True,
inline_const_tensor=True,
disallow_if_then_else=False,
require_injective=False,
require_ordered=False,
disallow_op=None,
),
M.MultiLevelTiling(
structure="SSSRRSRS",
tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
# use_tensor_core=True,
max_innermost_factor=64,
vector_load_lens=[1, 2, 3, 4],
reuse_read=schedule_rule.ReuseType(
req="must",
levels=[4],
scope="shared",
),
reuse_write=schedule_rule.ReuseType(
req="no",
levels=[],
scope="",
),
),
M.AutoInline(
into_producer=True,
into_consumer=True,
inline_const_tensor=True,
disallow_if_then_else=False,
require_injective=False,
require_ordered=False,
disallow_op=None,
),
M.ParallelizeVectorizeUnroll(
max_jobs_per_core=-1, # disable parallelize
max_vectorize_extent=-1, # disable vectorize
unroll_max_steps=[0, 16, 64, 512, 1024],
unroll_explicit=True,
),
]

@staticmethod
def _postproc():
from tvm.meta_schedule import (
postproc as M, # pylint: disable=import-outside-toplevel
)

return [
M.RewriteCooperativeFetch(),
M.RewriteParallelVectorizeUnroll(),
M.RewriteReductionBlock(),
M.RewriteTensorCore(),
M.VerifyGPUCode(),
]

with tempfile.TemporaryDirectory() as work_dir:
sch: Schedule = tune_tir(
mod=mod,
target=target,
config=config,
work_dir=work_dir,
space=PostOrderApply(),
sch_rules=DefaultTensorCore._sch_rules,
postprocs=DefaultTensorCore._postproc,
num_threads=None,
)
if sch is None:
print("No valid schedule found!")
else:
print(sch.mod.script())
print(sch.trace)

import numpy as np
from tvm.contrib import nvcc

ctx = tvm.gpu(0)
if nvcc.have_tensorcore(ctx.compute_version):
with tvm.transform.PassContext():
func = tvm.build(sch.mod["main"], [], "cuda")
print(sch.mod.script())
print(func.imported_modules[0].get_source())
a_np = np.random.uniform(size=(n, n)).astype("float16")
b_np = np.random.uniform(size=(n, n)).astype("float16")
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((n, n), dtype="float32"), ctx)
evaluator = func.time_evaluator(
func.entry_name, ctx, number=3, repeat=1, min_repeat_ms=40
)
print("matmul with tensor core: %f ms" % (evaluator(a, b, c).mean * 1e3))

np.testing.assert_allclose(
c.asnumpy(),
np.matmul(a_np.astype("float32"), b_np.astype("float32")),
rtol=1e-4,
atol=1e-4,
)


if __name__ == """__main__""":
test_tune_matmul_cpu()
test_tune_matmul_cuda()
test_tune_matmul_cuda_tensor_core()