Skip to content

Commit

Permalink
[MetaSchedule] Add utility API to ease using manual schedules (apache…
Browse files Browse the repository at this point in the history
…#10876)

As discussed in apache#10856 (comment), add a utility under `meta_schedule/testing/utils.py` to clean up the database boilerplate. Also using `DummyDatabase` instead of `JsonDatabase` for further clean up, as suggested by @junrushao1994 .
  • Loading branch information
masahi authored and Lucien0 committed Apr 19, 2022
1 parent 1408889 commit f8ea637
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 188 deletions.
9 changes: 8 additions & 1 deletion python/tvm/meta_schedule/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities in meta schedule"""
from .utils import DummyDatabase, DummyBuilder, DummyRunner, DummyRunnerFuture, DummyMutator
from .utils import (
DummyDatabase,
DummyBuilder,
DummyRunner,
DummyRunnerFuture,
DummyMutator,
apply_fixed_schedules,
)
53 changes: 51 additions & 2 deletions python/tvm/meta_schedule/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilitiy functions in meta schedule"""
from typing import List, Optional
import random
from typing import List, Optional, Callable, Dict, Union

import tvm

from tvm.relay import Function as RelayFunc
from tvm.tir import Schedule
from tvm.target import Target
from tvm.runtime import NDArray
from tvm.meta_schedule import TuneContext # pylint: disable=unused-import
from tvm.meta_schedule.utils import derived_object
from tvm.meta_schedule.mutator.mutator import PyMutator
Expand All @@ -32,6 +35,9 @@
PyRunnerFuture,
PyRunner,
)
from tvm.meta_schedule.tune import Parse, extract_task_from_relay
from tvm.meta_schedule.integration import ExtractedTask

from tvm.ir import IRModule
from tvm.tir.schedule import Trace

Expand Down Expand Up @@ -110,3 +116,46 @@ def initialize_with_tune_context(self, context: "TuneContext") -> None:

def apply(self, trace: Trace, _) -> Optional[Trace]:
return Trace(trace.insts, {})


def apply_fixed_schedules(
relay_mod: Union[RelayFunc, IRModule],
target: Union[str, Target],
params: Optional[Dict[str, NDArray]],
schedule_fn: Callable[[ExtractedTask, Schedule], bool],
):
"""Apply fixed schedules (manually written, without any tunable knobs) as specified by
schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest.
Parameters
----------
mod : Union[RelayFunc, IRModule]
The Relay module to apply fixed schedules.
target : Union[str, Target]
The target used to extract tasks.
params : Optional[Dict[str, tvm.runtime.NDArray]]
The associated parameters of the module.
schedule_fn : Callable[[ExtractedTask, Schedule], bool]
A callable that is applied for each extracted task and the corresponding default schedule.
Returns True if the given schedule should be committed to the database, False otherwise.
Returns
-------
database : Database
The database containing dummy tuning records for manually scheduled traces.
"""
target = Target(target) if isinstance(target, str) else target
extracted_tasks = extract_task_from_relay(relay_mod, target, params)

database = DummyDatabase()

for task in extracted_tasks:
mod = Parse._mod(task.dispatched[0])
sch = Schedule(mod)

if schedule_fn(task, sch):
workload = database.commit_workload(mod)
tune_rec = TuningRecord(sch.trace, [0.0], workload, target, [])
database.commit_tuning_record(tune_rec)

return database
39 changes: 8 additions & 31 deletions tests/python/unittest/test_meta_schedule_multi_anchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import tempfile

import numpy as np

import tvm
import tvm.testing
from tvm import relay
from tvm.meta_schedule.tune import Parse, extract_task_from_relay
from tvm.meta_schedule.database import TuningRecord, JSONDatabase
from tvm.meta_schedule.testing import apply_fixed_schedules
from tvm.meta_schedule.integration import ApplyHistoryBest


Expand Down Expand Up @@ -72,39 +68,20 @@ def test_dense_dense():

# print(relay.transform.InferType()(relay_mod))

target = "llvm"

data_np = np.random.randn(*data_shape).astype("float32")
weight1_np = np.random.randn(*weight_shape).astype("float32")
weight2_np = np.random.randn(*weight_shape).astype("float32")

target = "llvm"
params = {"weight1": weight1_np, "weight2": weight2_np}

extracted_tasks = extract_task_from_relay(relay_mod, target, params)

assert len(extracted_tasks) == 1

task = extracted_tasks[0]

mod = Parse._mod(task.dispatched[0])

with tempfile.TemporaryDirectory() as work_dir:
database = JSONDatabase(
path_workload=os.path.join(work_dir, "database_workload.json"),
path_tuning_record=os.path.join(work_dir, "database_tuning_record.json"),
)

workload = database.commit_workload(mod)

sch = tvm.tir.Schedule(mod)

schedule_dense_dense(sch)

# print(sch.mod.script())

tune_rec = TuningRecord(sch.trace, [0.0], workload, tvm.target.Target(target), [])
def schedule_fn(task, sch):
if "nn_dense_nn_dense" in task.task_name:
schedule_dense_dense(sch)
return True
return False

database.commit_tuning_record(tune_rec)
database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)

with ApplyHistoryBest(database):
with tvm.transform.PassContext(
Expand Down
60 changes: 27 additions & 33 deletions tests/python/unittest/test_meta_schedule_tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
from tvm.meta_schedule.database import JSONDatabase, PyDatabase, TuningRecord, Workload
from tvm.meta_schedule.integration import ApplyHistoryBest
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.testing import apply_fixed_schedules
from tvm.meta_schedule.tune import (
Parse,
extract_task_from_relay,
tune_extracted_tasks,
tune_relay,
Expand Down Expand Up @@ -480,52 +480,46 @@ def manual_tir_common(do_tune=False):

params = {"weight": weight_np, "bias": bias_np}

extracted_tasks = extract_task_from_relay(relay_mod, target, params)
if do_tune:
extracted_tasks = extract_task_from_relay(relay_mod, target, params)

# Filter out tasks that we don't intend to schedule / tune with TIR.
tune_tasks = list(
filter(
lambda task: "dense" in task.task_name,
extracted_tasks,
# Filter out tasks that we don't intend to schedule / tune with TIR.
tune_tasks = list(
filter(
lambda task: "dense" in task.task_name,
extracted_tasks,
)
)
config = ReplayTraceConfig(
num_trials_per_iter=64,
max_trials_per_task=64,
max_trials_global=20000,
)
)

with tempfile.TemporaryDirectory() as work_dir:
if do_tune:
config = ReplayTraceConfig(
num_trials_per_iter=64,
max_trials_per_task=64,
max_trials_global=20000,
)
with tempfile.TemporaryDirectory() as work_dir:
# postprocs=lambda: [] is important to prevent default post processors from
# tampering with the manual schedule.
database = tune_extracted_tasks(
tune_tasks, target, config, work_dir=work_dir, postprocs=lambda: []
)
else:
database = JSONDatabase(
path_workload=osp.join(work_dir, "database_workload.json"),
path_tuning_record=osp.join(work_dir, "database_tuning_record.json"),
)
else:

def schedule_fn(task, sch):
if "dense" not in task.task_name:
return False

for task in tune_tasks:
mod = Parse._mod(task.dispatched[0])
workload = database.commit_workload(mod)
block = sch.get_block("compute")

sch = tvm.tir.Schedule(mod)
block = sch.get_block("compute")
# Looks up schedule_rule annotation. See the comment in test_tune_relay_manual_tir_vnni().
schedule_rule = sch.get(block).annotations["schedule_rule"]

# Looks up schedule_rule annotation. See the comment in test_tune_relay_manual_tir_vnni().
schedule_rule = sch.get(block).annotations["schedule_rule"]
assert "dense_vnni" in schedule_rule

if "dense_vnni" in schedule_rule:
schedule_dense(block, M, False, sch)
schedule_dense(block, M, False, sch)

# [0.0] is for dummy measurement. There is only one tuning record so ApplyHistoryBest
# will always have only one option.
tune_rec = TuningRecord(sch.trace, [0.0], workload, tvm.target.Target(target), [])
return True

database.commit_tuning_record(tune_rec)
database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)

with ApplyHistoryBest(database):
with tvm.transform.PassContext(
Expand Down
121 changes: 0 additions & 121 deletions tests/python/unittest/test_meta_schedule_tune_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,127 +91,6 @@ def test_tune_matmul_cuda():
print(sch.trace)


@pytest.mark.skip("Integeration test")
def test_tune_matmul_cuda_tensor_core():
n = 512
mod = create_prim_func(te_workload.matmul_fp16(n, n, n))
target = Target("nvidia/geforce-rtx-3070")
config = ReplayTraceConfig(
num_trials_per_iter=32,
max_trials_per_task=320,
max_trials_global=320,
)

class DefaultTensorCore:
@staticmethod
def _sch_rules():
from tvm.meta_schedule import (
schedule_rule as M, # pylint: disable=import-outside-toplevel
)

return [
M.AutoInline(
into_producer=False,
into_consumer=True,
inline_const_tensor=True,
disallow_if_then_else=False,
require_injective=False,
require_ordered=False,
disallow_op=None,
),
M.MultiLevelTiling(
structure="SSSRRSRS",
tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
# use_tensor_core=True,
max_innermost_factor=64,
vector_load_lens=[1, 2, 3, 4],
reuse_read=schedule_rule.ReuseType(
req="must",
levels=[4],
scope="shared",
),
reuse_write=schedule_rule.ReuseType(
req="no",
levels=[],
scope="",
),
),
M.AutoInline(
into_producer=True,
into_consumer=True,
inline_const_tensor=True,
disallow_if_then_else=False,
require_injective=False,
require_ordered=False,
disallow_op=None,
),
M.ParallelizeVectorizeUnroll(
max_jobs_per_core=-1, # disable parallelize
max_vectorize_extent=-1, # disable vectorize
unroll_max_steps=[0, 16, 64, 512, 1024],
unroll_explicit=True,
),
]

@staticmethod
def _postproc():
from tvm.meta_schedule import (
postproc as M, # pylint: disable=import-outside-toplevel
)

return [
M.RewriteCooperativeFetch(),
M.RewriteParallelVectorizeUnroll(),
M.RewriteReductionBlock(),
M.RewriteTensorCore(),
M.VerifyGPUCode(),
]

with tempfile.TemporaryDirectory() as work_dir:
sch: Schedule = tune_tir(
mod=mod,
target=target,
config=config,
work_dir=work_dir,
space=PostOrderApply(),
sch_rules=DefaultTensorCore._sch_rules,
postprocs=DefaultTensorCore._postproc,
num_threads=None,
)
if sch is None:
print("No valid schedule found!")
else:
print(sch.mod.script())
print(sch.trace)

import numpy as np
from tvm.contrib import nvcc

ctx = tvm.gpu(0)
if nvcc.have_tensorcore(ctx.compute_version):
with tvm.transform.PassContext():
func = tvm.build(sch.mod["main"], [], "cuda")
print(sch.mod.script())
print(func.imported_modules[0].get_source())
a_np = np.random.uniform(size=(n, n)).astype("float16")
b_np = np.random.uniform(size=(n, n)).astype("float16")
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((n, n), dtype="float32"), ctx)
evaluator = func.time_evaluator(
func.entry_name, ctx, number=3, repeat=1, min_repeat_ms=40
)
print("matmul with tensor core: %f ms" % (evaluator(a, b, c).mean * 1e3))

np.testing.assert_allclose(
c.asnumpy(),
np.matmul(a_np.astype("float32"), b_np.astype("float32")),
rtol=1e-4,
atol=1e-4,
)


if __name__ == """__main__""":
test_tune_matmul_cpu()
test_tune_matmul_cuda()
test_tune_matmul_cuda_tensor_core()

0 comments on commit f8ea637

Please sign in to comment.