Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Fix Cyclic Dependency in PyClass Family #10368

Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions include/tvm/meta_schedule/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,18 @@ class RunnerFutureNode : public runtime::Object {
* \brief Check whether the runner has finished.
* \return A boolean indicating whether the runner has finished.
*/
bool Done() const { return f_done(); }
bool Done() const {
ICHECK(f_done != nullptr) << "PyRunnerFuture's Done method not implemented!";
return f_done();
}
/*!
* \brief Fetch the runner's output if it is ready.
* \return The runner's output.
*/
RunnerResult Result() const { return f_result(); }
RunnerResult Result() const {
ICHECK(f_result != nullptr) << "PyRunnerFuture's Result method not implemented!";
return f_result();
}

static constexpr const char* _type_key = "meta_schedule.RunnerFuture";
TVM_DECLARE_FINAL_OBJECT_INFO(RunnerFutureNode, runtime::Object);
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PySpaceGenerator's InitializeWithTuneContext !";
<< "PySpaceGenerator's InitializeWithTuneContext method not implemented!";
f_initialize_with_tune_context(context);
}

Expand Down
45 changes: 35 additions & 10 deletions python/tvm/meta_schedule/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""Meta Schedule builders that translate IRModule to runtime.Module, and then export"""
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional

from tvm._ffi import register_object
from tvm.ir import IRModule
from tvm.runtime import NDArray, Object
from tvm.target import Target

from .. import _ffi_api
from ..utils import check_override


@register_object("meta_schedule.BuilderInput")
Expand Down Expand Up @@ -125,17 +124,43 @@ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:


@register_object("meta_schedule.PyBuilder")
class PyBuilder(Builder):
"""An abstract builder with customized build method on the python-side."""
class _PyBuilder(Builder):
"""
A TVM object builder to support customization on the python side.
This is NOT the user facing class for function overloading inheritance.

def __init__(self):
"""Constructor."""
See also: PyBuilder
"""

@check_override(self.__class__, Builder)
def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]:
return self.build(build_inputs)
def __init__(self, methods: List[Callable]):
"""Constructor."""

self.__init_handle_by_constructor__(
_ffi_api.BuilderPyBuilder, # type: ignore # pylint: disable=no-member
f_build,
*methods,
)


class PyBuilder:
"""
An abstract builder with customized build method on the python-side.
This is the user facing class for function overloading inheritance.

Note: @derived_object is required for proper usage of any inherited class.
"""

_tvm_metadata = {"cls": _PyBuilder, "methods": ["build"]}

def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
"""Build the given inputs.

Parameters
----------
build_inputs : List[BuilderInput]
The inputs to be built.
Returns
-------
build_results : List[BuilderResult]
The results of building the given inputs.
"""
raise NotImplementedError
66 changes: 35 additions & 31 deletions python/tvm/meta_schedule/builder/local_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,19 @@
from tvm.target import Target

from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind
from ..utils import cpu_count, get_global_func_with_default_on_worker
from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker
from .builder import BuilderInput, BuilderResult, PyBuilder


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


T_BUILD = Callable[ # pylint: disable=invalid-name
[IRModule, Target, Optional[Dict[str, NDArray]]], Module
]
T_EXPORT = Callable[[Module], str] # pylint: disable=invalid-name


def _serialize_params(params: Optional[Dict[str, NDArray]]) -> Optional[bytearray]:
if params is None:
return None
Expand All @@ -45,6 +51,7 @@ def _deserialize_params(params: Optional[bytearray]) -> Optional[Dict[str, NDArr
return load_param_dict(params)


@derived_object
class LocalBuilder(PyBuilder):
"""A builder that builds the given input on local host.

Expand All @@ -54,10 +61,10 @@ class LocalBuilder(PyBuilder):
The process pool to run the build.
timeout_sec : float
The timeout in seconds for the build.
f_build : Union[None, str, LocalBuilder.T_BUILD]
f_build : Union[None, str, T_BUILD]
Name of the build function to be used.
Defaults to `meta_schedule.builder.default_build`.
f_export : Union[None, str, LocalBuilder.T_EXPORT]
f_export : Union[None, str, T_EXPORT]
Name of the export function to be used.
Defaults to `meta_schedule.builder.default_export`.

Expand Down Expand Up @@ -91,9 +98,6 @@ def default_export(mod: Module) -> str:
please send the registration logic via initializer.
"""

T_BUILD = Callable[[IRModule, Target, Optional[Dict[str, NDArray]]], Module]
T_EXPORT = Callable[[Module], str]

pool: PopenPoolExecutor
timeout_sec: float
f_build: Union[None, str, T_BUILD]
Expand All @@ -117,10 +121,10 @@ def __init__(
Defaults to number of CPUs.
timeout_sec : float
The timeout in seconds for the build.
f_build : LocalBuilder.T_BUILD
f_build : T_BUILD
Name of the build function to be used.
Defaults to `meta_schedule.builder.default_build`.
f_export : LocalBuilder.T_EXPORT
f_export : T_EXPORT
Name of the export function to be used.
Defaults to `meta_schedule.builder.default_export`.
initializer : Optional[Callable[[], None]]
Expand Down Expand Up @@ -148,7 +152,7 @@ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:

# Dispatch the build inputs to the worker processes.
for map_result in self.pool.map_with_error_catching(
lambda x: LocalBuilder._worker_func(*x),
lambda x: _worker_func(*x),
[
(
self.f_build,
Expand Down Expand Up @@ -188,28 +192,28 @@ def _check(f_build, f_export) -> None:
value = self.pool.submit(_check, self.f_build, self.f_export)
value.result()

@staticmethod
def _worker_func(
_f_build: Union[None, str, T_BUILD],
_f_export: Union[None, str, T_EXPORT],
mod: IRModule,
target: Target,
params: Optional[bytearray],
) -> str:
# Step 0. Get the registered functions
f_build: LocalBuilder.T_BUILD = get_global_func_with_default_on_worker(
_f_build,
default_build,
)
f_export: LocalBuilder.T_EXPORT = get_global_func_with_default_on_worker(
_f_export,
default_export,
)
# Step 1. Build the IRModule
rt_mod: Module = f_build(mod, target, _deserialize_params(params))
# Step 2. Export the Module
artifact_path: str = f_export(rt_mod)
return artifact_path

def _worker_func(
_f_build: Union[None, str, T_BUILD],
_f_export: Union[None, str, T_EXPORT],
mod: IRModule,
target: Target,
params: Optional[bytearray],
) -> str:
# Step 0. Get the registered functions
f_build: T_BUILD = get_global_func_with_default_on_worker(
_f_build,
default_build,
)
f_export: T_EXPORT = get_global_func_with_default_on_worker(
_f_export,
default_export,
)
# Step 1. Build the IRModule
rt_mod: Module = f_build(mod, target, _deserialize_params(params))
# Step 2. Export the Module
artifact_path: str = f_export(rt_mod)
return artifact_path


@register_func("meta_schedule.builder.default_build")
Expand Down
118 changes: 91 additions & 27 deletions python/tvm/meta_schedule/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@
# under the License.
"""Meta Schedule CostModel."""
import ctypes
from typing import List
from typing import Callable, List

import numpy as np # type: ignore
from tvm._ffi import register_object
from tvm.meta_schedule.utils import _get_default_str
from tvm.runtime import Object

from .. import _ffi_api
from ..runner import RunnerResult
from ..search_strategy import MeasureCandidate
from ..tune_context import TuneContext
from ..utils import _get_hex_address, check_override


@register_object("meta_schedule.CostModel")
Expand Down Expand Up @@ -99,41 +99,28 @@ def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> n


@register_object("meta_schedule.PyCostModel")
class PyCostModel(CostModel):
"""An abstract CostModel with customized methods on the python-side."""
class _PyCostModel(CostModel):
"""
A TVM object cost model to support customization on the python side.
This is NOT the user facing class for function overloading inheritance.

def __init__(self):
"""Constructor."""

@check_override(self.__class__, CostModel)
def f_load(path: str) -> None:
self.load(path)

@check_override(self.__class__, CostModel)
def f_save(path: str) -> None:
self.save(path)
See also: PyCostModel
"""

@check_override(self.__class__, CostModel)
def f_update(
context: TuneContext,
candidates: List[MeasureCandidate],
results: List[RunnerResult],
) -> None:
self.update(context, candidates, results)
def __init__(self, methods: List[Callable]):
"""Constructor."""

@check_override(self.__class__, CostModel)
def f_predict(context: TuneContext, candidates: List[MeasureCandidate], return_ptr) -> None:
n = len(candidates)
return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_double))
array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,))
array_wrapper[:] = self.predict(context, candidates)
res = predict_func(context, candidates)
array_wrapper[:] = res
assert (
array_wrapper.dtype == "float64"
), "ValueError: Invalid data type returned from CostModel Predict!"

def f_as_string() -> str:
return str(self)

f_load, f_save, f_update, predict_func, f_as_string = methods
self.__init_handle_by_constructor__(
_ffi_api.CostModelPyCostModel, # type: ignore # pylint: disable=no-member
f_load,
Expand All @@ -143,5 +130,82 @@ def f_as_string() -> str:
f_as_string,
)


class PyCostModel:
"""
An abstract cost model with customized methods on the python-side.
This is the user facing class for function overloading inheritance.

Note: @derived_object is required for proper usage of any inherited class.
"""

_tvm_metadata = {
"cls": _PyCostModel,
"methods": ["load", "save", "update", "predict", "__str__"],
}

def load(self, path: str) -> None:
"""Load the cost model from given file location.

Parameters
----------
path : str
The file path.
"""
raise NotImplementedError

def save(self, path: str) -> None:
"""Save the cost model to given file location.

Parameters
----------
path : str
The file path.
"""
raise NotImplementedError

def update(
self,
context: TuneContext,
candidates: List[MeasureCandidate],
results: List[RunnerResult],
) -> None:
"""Update the cost model given running results.

Parameters
----------
context : TuneContext,
The tuning context.
candidates : List[MeasureCandidate]
The measure candidates.
results : List[RunnerResult]
The running results of the measure candidates.
"""
raise NotImplementedError

def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
"""Update the cost model given running results.

Parameters
----------
context : TuneContext,
The tuning context.
candidates : List[MeasureCandidate]
The measure candidates.

Return
------
result : np.ndarray
The predicted normalized score.
"""
raise NotImplementedError

def __str__(self) -> str:
return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
"""Get the cost model as string with name.

Return
------
result : str
Get the cost model as string with name.
"""
return _get_default_str(self)
Loading