Skip to content

Commit

Permalink
Fix builder.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Feb 24, 2022
1 parent 92416f9 commit a5df3e5
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 42 deletions.
48 changes: 38 additions & 10 deletions python/tvm/meta_schedule/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""Meta Schedule builders that translate IRModule to runtime.Module, and then export"""
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional

from tvm._ffi import register_object
from tvm.ir import IRModule
from tvm.runtime import NDArray, Object
from tvm.target import Target

from .. import _ffi_api
from ..utils import check_override


@register_object("meta_schedule.BuilderInput")
Expand Down Expand Up @@ -124,17 +123,46 @@ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:


@register_object("meta_schedule.PyBuilder")
class PyBuilder(Builder):
"""An abstract builder with customized build method on the python-side."""
class _PyBuilder(Builder):
"""
A TVM object builder to support customization on the python side.
This is NOT the user facing class for function overloading inheritance.
def __init__(self):
"""Constructor."""
See also: PyBuilder
"""

@check_override(self.__class__, Builder)
def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]:
return self.build(build_inputs)
def __init__(self, methods: List[Callable]):
"""Constructor."""

self.__init_handle_by_constructor__(
_ffi_api.BuilderPyBuilder, # type: ignore # pylint: disable=no-member
f_build,
*methods,
)


class PyBuilder:
"""
An abstract builder with customized run method on the python-side.
This is the user facing class for function overloading inheritance.
Note: @derived_object is required for proper usage of any inherited class.
"""

tvm_metadata = {
"cls": _PyBuilder,
"methods": ["build"],
}

def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
"""Build the given inputs.
Parameters
----------
build_inputs : List[BuilderInput]
The inputs to be built.
Returns
-------
build_results : List[BuilderResult]
The results of building the given inputs.
"""
raise NotImplementedError
64 changes: 33 additions & 31 deletions python/tvm/meta_schedule/builder/local_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@
from tvm.target import Target

from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind
from ..utils import cpu_count, get_global_func_with_default_on_worker
from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker
from .builder import BuilderInput, BuilderResult, PyBuilder


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


T_BUILD = Callable[[IRModule, Target, Optional[Dict[str, NDArray]]], Module]
T_EXPORT = Callable[[Module], str]


def _serialize_params(params: Optional[Dict[str, NDArray]]) -> Optional[bytearray]:
if params is None:
return None
Expand All @@ -45,6 +49,7 @@ def _deserialize_params(params: Optional[bytearray]) -> Optional[Dict[str, NDArr
return load_param_dict(params)


@derived_object
class LocalBuilder(PyBuilder):
"""A builder that builds the given input on local host.
Expand All @@ -54,10 +59,10 @@ class LocalBuilder(PyBuilder):
The process pool to run the build.
timeout_sec : float
The timeout in seconds for the build.
f_build : Union[None, str, LocalBuilder.T_BUILD]
f_build : Union[None, str, T_BUILD]
Name of the build function to be used.
Defaults to `meta_schedule.builder.default_build`.
f_export : Union[None, str, LocalBuilder.T_EXPORT]
f_export : Union[None, str, T_EXPORT]
Name of the export function to be used.
Defaults to `meta_schedule.builder.default_export`.
Expand Down Expand Up @@ -91,9 +96,6 @@ def default_export(mod: Module) -> str:
please send the registration logic via initializer.
"""

T_BUILD = Callable[[IRModule, Target, Optional[Dict[str, NDArray]]], Module]
T_EXPORT = Callable[[Module], str]

pool: PopenPoolExecutor
timeout_sec: float
f_build: Union[None, str, T_BUILD]
Expand All @@ -117,10 +119,10 @@ def __init__(
Defaults to number of CPUs.
timeout_sec : float
The timeout in seconds for the build.
f_build : LocalBuilder.T_BUILD
f_build : T_BUILD
Name of the build function to be used.
Defaults to `meta_schedule.builder.default_build`.
f_export : LocalBuilder.T_EXPORT
f_export : T_EXPORT
Name of the export function to be used.
Defaults to `meta_schedule.builder.default_export`.
initializer : Optional[Callable[[], None]]
Expand Down Expand Up @@ -148,7 +150,7 @@ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:

# Dispatch the build inputs to the worker processes.
for map_result in self.pool.map_with_error_catching(
lambda x: LocalBuilder._worker_func(*x),
lambda x: _worker_func(*x),
[
(
self.f_build,
Expand Down Expand Up @@ -188,28 +190,28 @@ def _check(f_build, f_export) -> None:
value = self.pool.submit(_check, self.f_build, self.f_export)
value.result()

@staticmethod
def _worker_func(
_f_build: Union[None, str, T_BUILD],
_f_export: Union[None, str, T_EXPORT],
mod: IRModule,
target: Target,
params: Optional[bytearray],
) -> str:
# Step 0. Get the registered functions
f_build: LocalBuilder.T_BUILD = get_global_func_with_default_on_worker(
_f_build,
default_build,
)
f_export: LocalBuilder.T_EXPORT = get_global_func_with_default_on_worker(
_f_export,
default_export,
)
# Step 1. Build the IRModule
rt_mod: Module = f_build(mod, target, _deserialize_params(params))
# Step 2. Export the Module
artifact_path: str = f_export(rt_mod)
return artifact_path

def _worker_func(
_f_build: Union[None, str, T_BUILD],
_f_export: Union[None, str, T_EXPORT],
mod: IRModule,
target: Target,
params: Optional[bytearray],
) -> str:
# Step 0. Get the registered functions
f_build: T_BUILD = get_global_func_with_default_on_worker(
_f_build,
default_build,
)
f_export: T_EXPORT = get_global_func_with_default_on_worker(
_f_export,
default_export,
)
# Step 1. Build the IRModule
rt_mod: Module = f_build(mod, target, _deserialize_params(params))
# Step 2. Export the Module
artifact_path: str = f_export(rt_mod)
return artifact_path


@register_func("meta_schedule.builder.default_build")
Expand Down
14 changes: 13 additions & 1 deletion python/tvm/meta_schedule/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,16 @@ class PyRunner:
}

def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
raise NotImplementedError()
"""Run the built artifact and get runner futures.
Parameters
----------
runner_inputs : List[RunnerInput]
The inputs to the runner.
Returns
-------
runner_futures: List[RunnerFuture]
The runner futures.
"""
raise NotImplementedError
1 change: 1 addition & 0 deletions tests/python/unittest/test_meta_schedule_task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def result(self) -> RunnerResult:
return RunnerResult([random.uniform(5, 30) for _ in range(random.randint(1, 10))], None)


@derived_object
class DummyBuilder(PyBuilder):
def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
return [BuilderResult("test_path", None) for _ in build_inputs]
Expand Down

0 comments on commit a5df3e5

Please sign in to comment.