[Metaschedule] Add utility API to ease using manual schedules #10876
First file (meta schedule testing utilities): new imports and the `apply_fixed_schedules` helper.

```diff
@@ -15,11 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 """Testing utilitiy functions in meta schedule"""
-from typing import List, Optional
 import random
+from typing import List, Optional, Callable, Dict, Union

 import tvm
+from tvm.relay import Function as RelayFunc
 from tvm.tir import Schedule
+from tvm.target import Target
+from tvm.runtime import NDArray
 from tvm.meta_schedule import TuneContext  # pylint: disable=unused-import
 from tvm.meta_schedule.utils import derived_object
 from tvm.meta_schedule.mutator.mutator import PyMutator
@@ -32,6 +35,9 @@
     PyRunnerFuture,
     PyRunner,
 )
+from tvm.meta_schedule.tune import Parse, extract_task_from_relay
+from tvm.meta_schedule.integration import ExtractedTask
+
 from tvm.ir import IRModule
 from tvm.tir.schedule import Trace
@@ -110,3 +116,46 @@ def initialize_with_tune_context(self, context: "TuneContext") -> None:

     def apply(self, trace: Trace, _) -> Optional[Trace]:
         return Trace(trace.insts, {})
+
+
+def apply_fixed_schedules(
+    relay_mod: Union[RelayFunc, IRModule],
+    target: Union[str, Target],
+    params: Optional[Dict[str, NDArray]],
+    schedule_fn: Callable[[ExtractedTask, Schedule], bool],
+):
+    """Apply fixed schedules (manually written, without any tunable knobs) as specified by
+    schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest.
+
+    Parameters
+    ----------
+    relay_mod : Union[RelayFunc, IRModule]
+        The Relay module to apply fixed schedules to.
+    target : Union[str, Target]
+        The target used to extract tasks.
+    params : Optional[Dict[str, tvm.runtime.NDArray]]
+        The associated parameters of the module.
+    schedule_fn : Callable[[ExtractedTask, Schedule], bool]
+        A callable applied to each extracted task and the corresponding default schedule.
+        Returns True if the given schedule should be committed to the database, False otherwise.
+
+    Returns
+    -------
+    database : Database
+        The database containing dummy tuning records for manually scheduled traces.
+    """
+    target = Target(target) if isinstance(target, str) else target
+    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+
+    database = DummyDatabase()
+
+    for task in extracted_tasks:
+        mod = Parse._mod(task.dispatched[0])
+        sch = Schedule(mod)
+
+        if schedule_fn(task, sch):
+            workload = database.commit_workload(mod)
+            tune_rec = TuningRecord(sch.trace, [0.0], workload, target, [])
+            database.commit_tuning_record(tune_rec)
+
+    return database
```

Inline review thread on `workload = database.commit_workload(mod)`:

- "I think it should be …"
- "I don't think so. The purpose of this workload commit is to match against the unmodified mod during …"
- "That is correct, thank you."
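For orientation, a hedged usage sketch of the new helper (not part of the diff): `relay_mod` and `params` stand for an existing Relay workload, the task-name filter is purely illustrative, and the import path of `apply_fixed_schedules` is an assumption based on the testing-utilities module this diff touches. The `ApplyHistoryBest` build flow follows the docstring's note that the returned database is meant to be passed to `ApplyHistoryBest`:

```python
import tvm
from tvm import relay
from tvm.meta_schedule.integration import ApplyHistoryBest

# Assumed import path for the helper added in this PR.
from tvm.meta_schedule.testing.utils import apply_fixed_schedules


def schedule_fn(task, sch):
    # Illustrative filter: only commit manual schedules for dense tasks.
    if "dense" not in task.task_name:
        return False
    # ... apply manual schedule primitives to `sch` here ...
    return True  # commit this trace to the database


# relay_mod / params: an existing Relay model and its parameters (assumed).
database = apply_fixed_schedules(relay_mod, "llvm", params, schedule_fn)

# Build with the database so the committed traces are looked up instead of tuned.
with ApplyHistoryBest(database):
    with tvm.transform.PassContext(
        opt_level=3,
        config={"relay.backend.use_meta_schedule": True},
    ):
        lib = relay.build(relay_mod, target="llvm", params=params)
```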
Second file (meta schedule tuning tests): the tensor core integration test is removed.

```diff
@@ -91,127 +91,6 @@ def test_tune_matmul_cuda():
             print(sch.trace)


-@pytest.mark.skip("Integeration test")
-def test_tune_matmul_cuda_tensor_core():
-    n = 512
-    mod = create_prim_func(te_workload.matmul_fp16(n, n, n))
-    target = Target("nvidia/geforce-rtx-3070")
-    config = ReplayTraceConfig(
-        num_trials_per_iter=32,
-        max_trials_per_task=320,
-        max_trials_global=320,
-    )
-
-    class DefaultTensorCore:
-        @staticmethod
-        def _sch_rules():
-            from tvm.meta_schedule import (
-                schedule_rule as M,  # pylint: disable=import-outside-toplevel
-            )
-
-            return [
-                M.AutoInline(
-                    into_producer=False,
-                    into_consumer=True,
-                    inline_const_tensor=True,
-                    disallow_if_then_else=False,
-                    require_injective=False,
-                    require_ordered=False,
-                    disallow_op=None,
-                ),
-                M.MultiLevelTiling(
-                    structure="SSSRRSRS",
-                    tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
-                    # use_tensor_core=True,
-                    max_innermost_factor=64,
-                    vector_load_lens=[1, 2, 3, 4],
-                    reuse_read=schedule_rule.ReuseType(
-                        req="must",
-                        levels=[4],
-                        scope="shared",
-                    ),
-                    reuse_write=schedule_rule.ReuseType(
-                        req="no",
-                        levels=[],
-                        scope="",
-                    ),
-                ),
-                M.AutoInline(
-                    into_producer=True,
-                    into_consumer=True,
-                    inline_const_tensor=True,
-                    disallow_if_then_else=False,
-                    require_injective=False,
-                    require_ordered=False,
-                    disallow_op=None,
-                ),
-                M.ParallelizeVectorizeUnroll(
-                    max_jobs_per_core=-1,  # disable parallelize
-                    max_vectorize_extent=-1,  # disable vectorize
-                    unroll_max_steps=[0, 16, 64, 512, 1024],
-                    unroll_explicit=True,
-                ),
-            ]
-
-        @staticmethod
-        def _postproc():
-            from tvm.meta_schedule import (
-                postproc as M,  # pylint: disable=import-outside-toplevel
-            )
-
-            return [
-                M.RewriteCooperativeFetch(),
-                M.RewriteParallelVectorizeUnroll(),
-                M.RewriteReductionBlock(),
-                M.RewriteTensorCore(),
-                M.VerifyGPUCode(),
-            ]
-
-    with tempfile.TemporaryDirectory() as work_dir:
-        sch: Schedule = tune_tir(
-            mod=mod,
-            target=target,
-            config=config,
-            work_dir=work_dir,
-            space=PostOrderApply(),
-            sch_rules=DefaultTensorCore._sch_rules,
-            postprocs=DefaultTensorCore._postproc,
-            num_threads=None,
-        )
-        if sch is None:
-            print("No valid schedule found!")
-        else:
-            print(sch.mod.script())
-            print(sch.trace)
-
-            import numpy as np
-            from tvm.contrib import nvcc
-
-            ctx = tvm.gpu(0)
-            if nvcc.have_tensorcore(ctx.compute_version):
-                with tvm.transform.PassContext():
-                    func = tvm.build(sch.mod["main"], [], "cuda")
-                    print(sch.mod.script())
-                    print(func.imported_modules[0].get_source())
-                a_np = np.random.uniform(size=(n, n)).astype("float16")
-                b_np = np.random.uniform(size=(n, n)).astype("float16")
-                a = tvm.nd.array(a_np, ctx)
-                b = tvm.nd.array(b_np, ctx)
-                c = tvm.nd.array(np.zeros((n, n), dtype="float32"), ctx)
-                evaluator = func.time_evaluator(
-                    func.entry_name, ctx, number=3, repeat=1, min_repeat_ms=40
-                )
-                print("matmul with tensor core: %f ms" % (evaluator(a, b, c).mean * 1e3))
-
-                np.testing.assert_allclose(
-                    c.asnumpy(),
-                    np.matmul(a_np.astype("float32"), b_np.astype("float32")),
-                    rtol=1e-4,
-                    atol=1e-4,
-                )
-
-
 if __name__ == """__main__""":
     test_tune_matmul_cpu()
     test_tune_matmul_cuda()
-    test_tune_matmul_cuda_tensor_core()
```

Inline review thread on the removed test:

- "May I ask why this test is removed?"
- "This test depends on auto-tensorization for tensorcore, which is not in …"
- "I see. We may upstream it later then."
- "Ooops... Thanks for spotting this!"
Review discussion on the `schedule_fn` signature:

- "What about we change it to `Callable[[ExtractedTask], Schedule]`? I.e., the dispatched IRModule is available inside the task as input, and we can return a schedule if it matches our rule, otherwise return None. I think it might be better to avoid confusion."
- "That's possible, but that will force users to write … in every `schedule_fn` callback. I think this boilerplate is non-trivial (users shouldn't care about `dispatched` or `Parse` stuff)."
- "That makes sense, and I wonder if using the task name and schedule would suffice, if we don't want users to care about details inside of the extracted task."
- "There are some cases where having the relay mod is required. For example, I want to be able to skip tasks based on the output dtype of the compute, which can be retrieved from the relay mod."
- "Sounds good. Thanks for the explanation."
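Building on that last exchange, a minimal sketch of such a dtype-based filter might look as follows; the `task.mod` field holding the high-level Relay module and the simple tensor `ret_type` access are assumptions about the `ExtractedTask` layout, and the float16 condition is purely illustrative:

```python
def schedule_fn(task, sch):
    # `task` is an ExtractedTask; `task.mod` is assumed to expose the
    # high-level Relay module of the task (assumption, see lead-in).
    relay_func = task.mod["main"]
    # Illustrative filter: only hand-schedule computes with float16 output.
    # Assumes the function returns a single tensor, so ret_type has a .dtype.
    if relay_func.ret_type.dtype != "float16":
        return False
    # ... apply manual schedule primitives to `sch` here ...
    return True  # commit the manually scheduled trace to the database
```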