From 1d302204f284138ed8b8f65ec02d7a5f412d84ec Mon Sep 17 00:00:00 2001 From: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Mon, 14 Feb 2022 16:16:15 -0800 Subject: [PATCH] AutoTIR integration (#58) * [WIP] Basic task extraction mechanism is implemented. * [WIP] For gradual integration with Relay pipeline, meta_schedule/integration.py is created for relax to avoid potential conflict. * support tir tuning and injection mode * Add target field for Relax Extracted Task * 1. Create relax namespace/tvm objects/... for metaschedule to preserve relay support. 2. Promote target field from Optional to Target * Support ApplyHistoryBest * Reflect feedback from Yuchen * minor improvement and fix linter issue * add ASF header * Reorganize file structure * fix lint errors * remove the import-outside-toplevel * Reflect comments * remove redundant comment * As per discussion w/ Yuchen, ApplyHistoryBest is introduced as a Relax transformation pass. * remove redundant print msg * fix lint * reflect comments --- 3rdparty/cutlass | 2 +- include/tvm/relax/transform.h | 9 + python/tvm/meta_schedule/integration.py | 316 ++++++++++++++++++ python/tvm/relax/__init__.py | 1 - python/tvm/relax/transform/transform.py | 20 ++ python/tvm/relax/utils.py | 30 ++ python/tvm/relax/vm.py | 2 - src/meta_schedule/apply_history_best.cc | 1 - src/relax/backend/vm/vm_shape_lower.cc | 2 +- src/relax/transform/meta_schedule_ahb.cc | 79 +++++ .../python/relax/test_autotir_integration.py | 213 ++++++++++++ 11 files changed, 669 insertions(+), 6 deletions(-) create mode 100644 python/tvm/meta_schedule/integration.py create mode 100644 python/tvm/relax/utils.py create mode 100644 src/relax/transform/meta_schedule_ahb.cc create mode 100644 tests/python/relax/test_autotir_integration.py diff --git a/3rdparty/cutlass b/3rdparty/cutlass index c2ee13a0fe99..a3bcc6981d5d 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit c2ee13a0fe99241b0e798ce647acf98e237f1d0c +Subproject commit a3bcc6981d5dad3afb212689e2c7853d1b1ee45d diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index c4477249d425..877c79ed3afd 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -25,6 +25,7 @@ #define TVM_RELAX_TRANSFORM_H_ #include +#include #include namespace tvm { @@ -78,6 +79,14 @@ TVM_DLL Pass CallTIRRewrite(); */ TVM_DLL Pass ToANF(); +/*! + * \brief Apply the best schedule from tuning database. + * + * \return The Pass. + */ +TVM_DLL Pass MetaScheduleApplyHistoryBest(const tvm::meta_schedule::Database& database, + Target target); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py new file mode 100644 index 000000000000..f0ee92b55bc9 --- /dev/null +++ b/python/tvm/meta_schedule/integration.py @@ -0,0 +1,316 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta schedule integration with high-level IR""" +from typing import Dict, List, Optional, Union + +import numpy as np # type: ignore +import tvm.runtime.ndarray as nd + +from tvm._ffi import register_object, get_global_func +from tvm.ir import IRModule, transform +from tvm.relay import Any +from tvm.relay import Function as RelayFunc +from tvm.runtime import NDArray, Object +from tvm.target import Target +from tvm.tir import PrimFunc +from tvm.relax.expr import Function as RelaxFunc +from tvm.relax.utils import tir_partitioner +from tvm.relax.ty import DynTensorType + +from . import _ffi_api +from .database import Database + + +@register_object("meta_schedule.ExtractedTask") +class ExtractedTask(Object): + """A tuning task extracted from the high-level IR + + Parameters + ---------- + task_name : str + The name of the task extracted + mod : IRModule + The high-level IR + target: Target + Target information + dispatched : List[IRModule] + A list of low-level IRs that the high-level IR could potentially dispatch to + """ + + task_name: str + mod: IRModule + dispatched: List[IRModule] + + def __init__( + self, + task_name: str, + mod: IRModule, + target: Target, + dispatched: List[IRModule], + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ExtractedTask, # type: ignore # pylint: disable=no-member + task_name, + mod, + target, + dispatched, + ) + + +@register_object("meta_schedule.MetaScheduleContext") +class MetaScheduleContext(Object): + """A context manager interface for the integration""" + + def query( + self, + task_name: str, + mod: IRModule, + target: Target, + dispatched: Optional[List[IRModule]], + ) -> Union[IRModule, None]: + """The entry point of the integration + + Parameters + ---------- + task_name : str + The name of the task extracted + mod : IRModule + The high-level IR + target: Target + Target Info + dispatched : Optional[List[IRModule]] + A list of low-level IRs that the high-level IR could potentially dispatch to + + Returns + ------- + result : IRModule or None + Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for + more general future use. None is returned if there is no feedback hint. + """ + return _ffi_api.MetaScheduleContextQuery( # type: ignore # pylint: disable=no-member + self, + task_name, + mod, + target, + dispatched, + ) + + @staticmethod + def current() -> Optional["MetaScheduleContext"]: + """The context manager in the current scope + + Returns + ------- + ctx : Optional[MetaScheduleContext] + The MetaScheduleContext in the current scope. + NullOpt if it's currently not under any MetaScheduleContext. + """ + return _ffi_api.MetaScheduleContextCurrent() # type: ignore # pylint: disable=no-member + + @staticmethod + def query_inside_with_scope( + task_name: str, + mod: IRModule, + target: Target, + dispatched: Optional[List[IRModule]], + ) -> Union[IRModule, None]: + """The entry point of the integration workflow. The compilation process of the high-level + IR should call this method for task extraction and for feedback hints + + Basically, this method is equivalent to: + + .. 
code-block:: python + + def query_inside_with_scope(task_name, mod, dispatched): + ctx = MetaScheduleContext.current() + assert ctx is not None + mod = ctx.query(task_name, mod, target, dispatched) + + Parameters + ---------- + task_name : str + The name of the task + mod : IRModule + The high-level IR + target: Target + Target + dispatched : Optional[List[IRModule]] + A list of low-level IRs that the high-level IR could potentially dispatch to + + Returns + ------- + result : IRModule or None + Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for + more general future use. None is returned if there is no feedback hint. + """ + return _ffi_api.MetaScheduleContextQueryInsideWithScope( # type: ignore # pylint: disable=no-member + task_name, + mod, + target, + dispatched, + ) + + def __enter__(self) -> "MetaScheduleContext": + """Entering the scope of the context manager""" + _ffi_api.MetaScheduleContextEnterScope(self) # type: ignore # pylint: disable=no-member + return self + + def __exit__(self, ptype, value, trace) -> None: + """Exiting the scope of the context manager""" + _ffi_api.MetaScheduleContextExitScope(self) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.ApplyHistoryBest") +class ApplyHistoryBest(MetaScheduleContext): + """An integration context that allows application of historically best record from database""" + + database: Database + """ The database to be queried from""" + + def __init__(self, database) -> None: + self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member + + +def extract_task_from_relay( + mod: Union[IRModule, RelayFunc], + target: Target, + params: Optional[Dict[str, NDArray]] = None, + *, + opt_level: int = 3, + pass_config: Optional[Dict[str, Any]] = None, + disabled_pass: Optional[List[str]] = None, +) -> List[ExtractedTask]: + """Extract tuning tasks from a relay program. + + Parameters + ---------- + mod : Union[tvm.IRModule, tvm.relay.Function] + The module or function to tune + target : tvm.target.Target + The compilation target + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + opt_level : int + The optimization level of the compiler + pass_config : Optional[Dict[str, Any]] + The pass config of the compiler + disabled_pass : Optional[List[str]] + The list of disabled passes of the compiler + + Returns + ------- + tasks: List[ExtractedTask] + The tasks extracted from this network + """ + + extract_task_func = get_global_func("relay.backend.MetaScheduleExtractTask") + assert extract_task_func + + target = Target(target) if isinstance(target, str) else target + + relay_params = {} + for name, param in params.items(): + if isinstance(param, np.ndarray): + param = nd.array(param) + relay_params[name] = param + + if disabled_pass is None: + disabled_pass = [] + if pass_config is None: + pass_config = {"relay.backend.use_meta_schedule": True} + + if isinstance(mod, RelayFunc): + mod = IRModule.from_expr(mod) + if not isinstance(target, Target): + target = Target(target) + + with target, transform.PassContext( + opt_level=opt_level, + config=pass_config, + disabled_pass=disabled_pass, + ): + tasks = extract_task_func(mod, target, relay_params) + # Tasks are extracted via post order visit, return the reversed list. 
+ return list(reversed(tasks)) + + +def extract_task_from_relax( + mod: Union[IRModule, RelaxFunc], + target: Target, + *, + opt_level: int = 3, + pass_config: Dict[str, DynTensorType] = {}, + disabled_pass: List[str] = [], +) -> List[ExtractedTask]: + """Extract tuning tasks from a relax program. + + Parameters + ---------- + mod : Union[tvm.IRModule, tvm.relax.Function] + The module or function to tune + target : tvm.target.Target + The compilation target + opt_level : int + The optimization level of the compiler + pass_config : Dict[str, DynTensorType] + The pass config of the compiler + disabled_pass : List[str] + The list of disabled passes of the compiler + + Returns + ------- + tasks: List[ExtractedTask] + The tasks extracted from this network + """ + + @contextmanager + def _autotvm_silencer(): + from tvm import autotvm # pylint: disable=import-outside-toplevel + + silent = autotvm.GLOBAL_SCOPE.silent + autotvm.GLOBAL_SCOPE.silent = True + try: + yield + finally: + autotvm.GLOBAL_SCOPE.silent = silent + + def _thread_run(func: Callable[[], None]) -> None: + import threading # pylint: disable=import-outside-toplevel + + thread = threading.Thread(target=func) + thread.start() + thread.join() + + env = TaskExtraction() + if isinstance(mod, RelaxFunc): + mod = IRModule.from_expr(mod) + if not isinstance(target, Target): + target = Target(target) + + def _func(): + with env, _autotvm_silencer(), transform.PassContext( + config=pass_config, + disabled_pass=disabled_pass, + opt_level=opt_level, + ): + tir_partitions = tir_partitioner(mod) + for tir_mod in tir_partitions: + func_name = tir_mod.get_global_vars()[0].name_hint + MetaScheduleContext.query_inside_with_scope(func_name, tir_mod, target, [tir_mod]) + + _thread_run(_func) + return env.tasks diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 6037b4135824..c603b28f1d8b 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -25,7 +25,6 @@ from . import analysis from . import transform - # Expr Expr = expr.Expr Span = expr.Span diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 891de1470d15..8dc396299026 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -17,6 +17,8 @@ # pylint: disable=invalid-name """Relax transformation passes.""" import tvm.ir +from tvm.target import Target +from tvm.meta_schedule.database import PyDatabase from . import _ffi_api @@ -98,3 +100,21 @@ def ResolveGlobals() -> tvm.ir.transform.Pass: ret: tvm.ir.transform.Pass """ return _ffi_api.ResolveGlobals() + + +def MetaScheduleApplyHistoryBest( + database: PyDatabase, + target: Target, +) -> tvm.ir.transform.Pass: + """Apply the best schedule from tuning database. + Parameters + ---------- + database : metaschedule tuning database + target: target info + + Returns + ------- + ret: tvm.ir.transform.Pass + + """ + return _ffi_api.MetaScheduleApplyHistoryBest(database, target) diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py new file mode 100644 index 000000000000..137f70a06ca5 --- /dev/null +++ b/python/tvm/relax/utils.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utility functions for Relax""" +from typing import List +from tvm.tir import PrimFunc +from tvm import IRModule + +# Simply extracts tir PrimFuncs from the input IRModule +def tir_partitioner(mod: IRModule) -> List[IRModule]: + partitions = [] + for gvar in mod.get_global_vars(): + if isinstance(mod[gvar], PrimFunc): + tir_mod = IRModule({}) + tir_mod[gvar] = mod[gvar] + partitions.append(tir_mod) + return partitions diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index 9ae9850c9731..79399723a421 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -21,7 +21,6 @@ from tvm import relax from tvm.ir.module import IRModule from tvm.runtime import Object, Device, Module, PackedFunc - from tvm.tir.function import PrimFunc from . import _ffi_api from ..rpc.base import RPC_SESS_MASK @@ -187,7 +186,6 @@ def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]): # split primfunc and relax function rx_mod, tir_mod = _split_tir_relax(new_mod) - lib = tvm.build(tir_mod, target) ex = _ffi_api.VMCodeGen(rx_mod) return ex, lib diff --git a/src/meta_schedule/apply_history_best.cc b/src/meta_schedule/apply_history_best.cc index 62db29306777..a5e27b0553f1 100644 --- a/src/meta_schedule/apply_history_best.cc +++ b/src/meta_schedule/apply_history_best.cc @@ -108,7 +108,6 @@ Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModu FDirectDispatch f_direct_dispatch) { ICHECK(dispatched.defined()); ICHECK_EQ(dispatched.value().size(), 1); - ICHECK(HasOnlyOneFunction(mod)) << mod; IRModule prim_mod = dispatched.value()[0]; ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 8a7c44a2e2a7..043ec4bb4898 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -41,7 +41,7 @@ class VMShapeLowerMutator : public ExprMutator { IRModule Lower() { for (auto& p : mod_->functions) { Expr func = p.second; - if (p.second->IsInstance()) { + if (func->IsInstance()) { // prepare mapping and heap var expr2slot_ = PrepareExpr2Slot(Downcast(func)); heap_size_ = IntImm(ShapeDType(), expr2slot_.size()); diff --git a/src/relax/transform/meta_schedule_ahb.cc b/src/relax/transform/meta_schedule_ahb.cc new file mode 100644 index 000000000000..542d5dd6f8c0 --- /dev/null +++ b/src/relax/transform/meta_schedule_ahb.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/meta_schedule_ahb.cc + * \brief Pass for applying the best schedule from tuning database. + */ + +#include + +namespace tvm { +namespace relax { + +class MetaScheduleAHB { + public: + explicit MetaScheduleAHB(IRModule mod, const tvm::meta_schedule::Database& db, Target target) + : mod_(mod), db_(db), target_(target) {} + IRModule Apply() { + ret_mod_ = IRModule(); + tvm::meta_schedule::ApplyHistoryBest ahb(db_); + for (auto& p : mod_->functions) { + GlobalVar gv = p.first; + BaseFunc func = p.second; + BaseFunc newfunc = func; + if (func->IsInstance()) { + IRModule tir_mod(Map({{gv, func}})); + ObjectRef res = ahb->Query(gv->name_hint, mod_, target_, Array{tir_mod}); + // replace the tir func only when the schedule is found in tuning database. + if (res.defined()) { + IRModule newmod = Downcast(res); + ICHECK_EQ(newmod->functions.size(), 1); + newfunc = (*newmod->functions.begin()).second; + } + } + + ret_mod_->Add(gv, newfunc); + } + return ret_mod_; + } + + private: + IRModule mod_; + const tvm::meta_schedule::Database& db_; + Target target_; + IRModule ret_mod_; +}; + +namespace transform { + +Pass MetaScheduleApplyHistoryBest(const tvm::meta_schedule::Database& database, Target target) { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return MetaScheduleAHB(m, database, target).Apply(); }; + return CreateModulePass(/*pass function*/ pass_func, /*opt level*/ 0, + /*pass name*/ "MetaScheduleApplyHistoryBest", + /*required*/ {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleApplyHistoryBest") + .set_body_typed(MetaScheduleApplyHistoryBest); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_autotir_integration.py b/tests/python/relax/test_autotir_integration.py new file mode 100644 index 000000000000..f3ce56b8fc92 --- /dev/null +++ b/tests/python/relax/test_autotir_integration.py @@ -0,0 +1,213 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations +import tvm +from tvm.script import tir as T, relax as R +from tvm import relax +import numpy as np +from tvm.tir import Schedule +from tvm.ir.module import IRModule +from tvm.target.target import Target +import tempfile +from typing import List +from tvm.meta_schedule import ReplayTraceConfig, tune_tir +from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord +from tvm.meta_schedule.integration import extract_task_from_relax +from tvm import transform +import time + +# Test case with dynamic shape. +# Tuning with dynamic shape is not supported yet. +""" +class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m,n)) + B = T.match_buffer(y, (n,k)) + C = T.match_buffer(z, (m,k)) + + for (i0, j0, k0) in T.grid(m,n,k): + with T.block(): + i,j,k = T.axis.remap("SSR", [i0,j0,k0]) + with T.init(): + C[i,j] = 0.0 + C[i,j] += A[i,k] * B[j,k] + + @T.prim_func + def tir_relu(x:T.handle, y:T.handle): + T.func_attr({"global_symbol": "tir_relu"}) + m = T.var("int32") + n = T.var("int32") + A = T.match_buffer(x, (m,n)) + B = T.match_buffer(y, (m,n)) + for (i,j) in T.grid(m,n): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.max(A[vi, vj], 0.0) + + @R.function + def main(x:Tensor[(m,n), "float32"], w:Tensor[(n,k), "float32"]) -> Tensor: + with R.dataflow(): + sh = relax.call_packed("vm.builtin.shape_of", x) + x0 = relax.match_shape(sh, (m, n)) + sh1 = relax.call_packed("vm.builtin.shape_of", w) + x1 = relax.match_shape(sh1, (n, k)) + lv0 = R.call_tir((m,k), tir_matmul, (x,w)) + lv1 = R.call_tir((m,k), tir_relu, (lv0)) + relax.output(lv1) + return lv1 +""" + + +class DummyDatabase(PyDatabase): + def __init__(self): + super().__init__() + self.records = [] + self.workload_reg = [] + + def has_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True + return False + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.records.append(record) + + def commit_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return workload + workload = Workload(mod) + self.workload_reg.append(workload) + return workload + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + return list( + filter( + lambda x: x.workload == workload, + sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), + ) + )[: int(top_k)] + + def __len__(self) -> int: + return len(self.records) + + def print_results(self) -> None: + print("\n".join([str(r) for r in self.records])) + + +def test_class_irmodule(dev: str): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + C = T.match_buffer(z, (32, 32)) + + for (i0, j0, k0) in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] + + @T.prim_func + def tir_relu(x: T.handle, y: T.handle): + T.func_attr({"global_symbol": "tir_relu"}) + m = T.var("int32") + n = T.var("int32") + A = 
T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + for (i, j) in T.grid(32, 32): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.max(A[vi, vj], 0.0) + + @R.function + def main(x: Tensor[(32, 32), "float32"], w: Tensor[(32, 32), "float32"]) -> Tensor: + with R.dataflow(): + lv0 = R.call_tir((32, 32), tir_matmul, (x, w)) + lv1 = R.call_tir((32, 32), tir_relu, (lv0)) + relax.output(lv1) + return lv1 + + mod = InputModule + assert isinstance(mod, tvm.IRModule) + + if dev == "cpu": + target = Target("llvm --num-cores=16") + dev = tvm.cpu() + else: + target = Target("nvidia/geforce-rtx-3070") + dev = tvm.cuda() + + database = DummyDatabase() + tasks = extract_task_from_relax(mod, target=target) + for task in tasks: + print(f"Extracted task: {task.task_name}, {task.target}") + with tempfile.TemporaryDirectory() as work_dir: + sch: Schedule = tune_tir( + mod=task.mod, + target=target, + config=ReplayTraceConfig( + num_trials_per_iter=32, + num_trials_total=32, + ), + work_dir=work_dir, + database=database, + ) + + with transform.PassContext(opt_level=3): + ex0, lib0 = relax.vm.build(mod, target) + + with transform.PassContext(opt_level=3): + mod = relax.transform.MetaScheduleApplyHistoryBest(database, target)(mod) + ex1, lib1 = relax.vm.build(mod, target) + + vm0 = relax.VirtualMachine(ex0, dev, mod=lib0) + vm1 = relax.VirtualMachine(ex1, dev, mod=lib1) + data = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) + + # Measure the performance w/o tuning log + tic = time.time() + vm0["main"](data, weight) + toc = time.time() + e0 = toc - tic + + # Measure the performance w/ tuning log + tic = time.time() + vm1["main"](data, weight) + toc = time.time() + e1 = toc - tic + + print(f"w/o tuning: {e0}") + print(f"w/ tuning: {e1}") + + +if __name__ == "__main__": + test_class_irmodule(dev="cpu")
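
For reference, below is a minimal end-to-end sketch of the workflow this patch enables (task extraction -> TIR tuning -> ApplyHistoryBest as a Relax pass), distilled from test_class_irmodule above. It is a sketch under the same assumptions as that test: `mod` stands for a Relax IRModule such as InputModule, and `DummyDatabase` is the in-memory database defined in the test (any meta_schedule Database implementation could be substituted).

    import tempfile

    import tvm
    from tvm import relax, transform
    from tvm.target import Target
    from tvm.meta_schedule import ReplayTraceConfig, tune_tir
    from tvm.meta_schedule.integration import extract_task_from_relax

    target = Target("llvm --num-cores=16")
    database = DummyDatabase()  # assumed: the PyDatabase subclass from the test above

    # 1. Extract one tuning task per TIR PrimFunc in the Relax module.
    tasks = extract_task_from_relax(mod, target=target)

    # 2. Tune each extracted task; tuning records are committed to `database`.
    with tempfile.TemporaryDirectory() as work_dir:
        for task in tasks:
            tune_tir(
                mod=task.mod,
                target=target,
                config=ReplayTraceConfig(num_trials_per_iter=32, num_trials_total=32),
                work_dir=work_dir,
                database=database,
            )

    # 3. Replace each PrimFunc with the best schedule found in the database,
    #    then build and run the module as usual.
    with transform.PassContext(opt_level=3):
        mod = relax.transform.MetaScheduleApplyHistoryBest(database, target)(mod)
        ex, lib = relax.vm.build(mod, target)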