Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Fix Cyclic Dependency in PyClass Family #10368

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Staging changes.
  • Loading branch information
zxybazh committed Mar 2, 2022
commit 8ad4c32197029fae73bc2b80d35b3c62857ab6b4
10 changes: 8 additions & 2 deletions include/tvm/meta_schedule/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,18 @@ class RunnerFutureNode : public runtime::Object {
* \brief Check whether the runner has finished.
* \return A boolean indicating whether the runner has finished.
*/
bool Done() const { return f_done(); }
bool Done() const {
ICHECK(f_done != nullptr) << "PyRunnerFuture's Done method not implemented!";
return f_done();
}
/*!
* \brief Fetch the runner's output if it is ready.
* \return The runner's output.
*/
RunnerResult Result() const { return f_result(); }
RunnerResult Result() const {
ICHECK(f_result != nullptr) << "PyRunnerFuture's Result method not implemented!";
return f_result();
}

static constexpr const char* _type_key = "meta_schedule.RunnerFuture";
TVM_DECLARE_FINAL_OBJECT_INFO(RunnerFutureNode, runtime::Object);
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PySpaceGenerator's InitializeWithTuneContext !";
<< "PySpaceGenerator's InitializeWithTuneContext method not implemented!";
f_initialize_with_tune_context(context);
}

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class PyBuilder:
Note: @derived_object is required for proper usage of any inherited class.
"""

_tvm_metadata = {"cls": _PyBuilder, "methods": ["build"], "required": {"build"}}
_tvm_metadata = {"cls": _PyBuilder, "methods": ["build"]}

def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
"""Build the given inputs.
Expand Down
26 changes: 20 additions & 6 deletions python/tvm/meta_schedule/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np # type: ignore
from tvm._ffi import register_object
from tvm.meta_schedule.utils import get_hex_address
from tvm.runtime import Object

from .. import _ffi_api
Expand Down Expand Up @@ -119,10 +120,7 @@ def f_predict(context: TuneContext, candidates: List[MeasureCandidate], return_p
array_wrapper.dtype == "float64"
), "ValueError: Invalid data type returned from CostModel Predict!"

def f_as_string():
return self.__str__()

f_load, f_save, f_update, predict_func = methods
f_load, f_save, f_update, predict_func, f_as_string = methods
self.__init_handle_by_constructor__(
_ffi_api.CostModelPyCostModel, # type: ignore # pylint: disable=no-member
f_load,
Expand All @@ -143,8 +141,7 @@ class PyCostModel:

_tvm_metadata = {
"cls": _PyCostModel,
"methods": ["load", "save", "update", "predict"],
"required": {"load", "save", "update", "predict"},
"methods": ["load", "save", "update", "predict", "__str__"],
}

def load(self, path: str) -> None:
Expand Down Expand Up @@ -202,3 +199,20 @@ def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> n
The predicted normalized score.
"""
raise NotImplementedError

def __str__(self) -> str:
"""Update the cost model given running results.

Parameters
----------
context : TuneContext,
The tuning context.
candidates : List[MeasureCandidate]
The measure candidates.

Return
------
result : np.ndarray
The predicted normalized score.
"""
return f"meta_schedule.{self.__class__.__name__}({get_hex_address(self.handle)})"
7 changes: 0 additions & 7 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,6 @@ class PyDatabase:
"get_top_k",
"__len__",
],
"required": {
"has_workload",
"commit_workload",
"commit_tuning_record",
"get_top_k",
"__len__",
},
}

def has_workload(self, mod: IRModule) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm.runtime.ndarray import NDArray

from .. import _ffi_api
from ..utils import _get_hex_address, check_override
from ..utils import get_hex_address, check_override
from ..tune_context import TuneContext
from ..search_strategy import MeasureCandidate

Expand Down Expand Up @@ -78,4 +78,4 @@ def f_as_string() -> str:
)

def __str__(self) -> str:
return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
return f"{self.__class__.__name__}({get_hex_address(self.handle)})"
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/measure_callback/measure_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ..builder import BuilderResult
from ..runner import RunnerResult
from ..search_strategy import MeasureCandidate
from ..utils import _get_hex_address, check_override
from ..utils import get_hex_address, check_override

if TYPE_CHECKING:
from ..task_scheduler import TaskScheduler
Expand Down Expand Up @@ -101,4 +101,4 @@ def f_as_string() -> str:
)

def __str__(self) -> str:
return f"PyMeasureCallback({_get_hex_address(self.handle)})"
return f"PyMeasureCallback({get_hex_address(self.handle)})"
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/mutator/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm.tir.schedule import Trace

from .. import _ffi_api
from ..utils import _get_hex_address, check_override
from ..utils import get_hex_address, check_override

if TYPE_CHECKING:
from ..tune_context import TuneContext
Expand Down Expand Up @@ -85,4 +85,4 @@ def f_as_string() -> str:
)

def __str__(self) -> str:
return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
return f"{self.__class__.__name__}({get_hex_address(self.handle)})"
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/postproc/postproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tvm.tir.schedule import Schedule

from .. import _ffi_api
from ..utils import _get_hex_address, check_override
from ..utils import get_hex_address, check_override

if TYPE_CHECKING:
from ..tune_context import TuneContext
Expand Down Expand Up @@ -87,4 +87,4 @@ def f_as_string() -> str:
)

def __str__(self) -> str:
return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
return f"{self.__class__.__name__}({get_hex_address(self.handle)})"
6 changes: 4 additions & 2 deletions python/tvm/meta_schedule/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def LocalRunnerFuture(PyRunnerFuture):
_tvm_metadata = {
"cls": RunnerFuture,
"methods": ["done", "result"],
"required": {"done", "result"},
}

def done(self) -> bool:
Expand Down Expand Up @@ -204,7 +203,10 @@ class PyRunner:
Note: @derived_object is required for proper usage of any inherited class.
"""

_tvm_metadata = {"cls": _PyRunner, "methods": ["run"], "required": {"run"}}
_tvm_metadata = {
"cls": _PyRunner,
"methods": ["run"],
}

def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
"""Run the built artifact and get runner futures.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/schedule_rule/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tvm.runtime import Object
from tvm.tir.schedule import Schedule, BlockRV

from ..utils import _get_hex_address, check_override
from ..utils import get_hex_address, check_override
from .. import _ffi_api

if TYPE_CHECKING:
Expand Down Expand Up @@ -93,4 +93,4 @@ def f_as_string() -> str:
)

def __str__(self) -> str:
return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
return f"{self.__class__.__name__}({get_hex_address(self.handle)})"
7 changes: 0 additions & 7 deletions python/tvm/meta_schedule/search_strategy/search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,6 @@ class PySearchStrategy:
"generate_measure_candidates",
"notify_runner_results",
],
"required": {
"initialize_with_tune_context",
"pre_tuning",
"post_tuning",
"generate_measure_candidates",
"notify_runner_results",
},
}

def initialize_with_tune_context(self, context: "TuneContext") -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ class PySpaceGenerator:
_tvm_metadata = {
"cls": _PySpaceGenerator,
"methods": ["initialize_with_tune_context", "generate_design_space"],
"required": {"initialize_with_tune_context", "generate_design_space"},
}

def initialize_with_tune_context(self, context: "TuneContext") -> None:
Expand Down
1 change: 0 additions & 1 deletion python/tvm/meta_schedule/task_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ class PyTaskScheduler:
"join_running_task",
"next_task_id",
],
"required": {"next_task_id"},
}

def __init__(
Expand Down
24 changes: 10 additions & 14 deletions python/tvm/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

def derived_object(cls: Any) -> type:
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
"""A decorator to register derived subclasses for TVM objects.

Parameters
----------
cls : type
Expand All @@ -55,8 +56,7 @@ def __init__(self, methods) -> None:
class PyRunner():
_tvm_metadata = {
"cls": _PyRunner,
"methods": ["run"],
"required": {"run"}
"methods": ["run"]
}
def run(self, runner_inputs):
raise NotImplementedError
Expand All @@ -70,17 +70,16 @@ def run(self, runner_inputs):
import functools # pylint: disable=import-outside-toplevel
import weakref # pylint: disable=import-outside-toplevel

def _extract(inst: Any, name: str, required: Set[str]):
def _extract(inst: Any, name: str):
"""Extract function from intrinsic class."""

def method(*args, **kwargs):
return getattr(inst, name)(*args, **kwargs)

if getattr(base, name) is getattr(cls, name):
# return nullptr to use default function on the c++ side
# for task scheduler use only
if name in required:
raise NotImplementedError(f"{cls}'s {name} method is not implemented!")
if getattr(base, name) is getattr(cls, name) and name != "__str__":
# for task scheduler return None means calling default function
# otherwise it will trigger a TVMError of method not implemented
# on the c++ side when you call the method, __str__ not required
return None
return method

Expand All @@ -93,7 +92,6 @@ def method(*args, **kwargs):
metadata = getattr(base, "_tvm_metadata")
members = metadata.get("members", [])
methods = metadata.get("methods", [])
required = metadata.get("required", {})

class TVMDerivedObject(metadata["cls"]): # type: ignore
"""The derived object to avoid cyclic dependency."""
Expand All @@ -110,8 +108,9 @@ def __init__(self, *args, **kwargs):
# the constructor's parameters, builder, runner, etc.
*[getattr(self._inst, name) for name in members],
# the function methods, init_with_tune_context, build, run, etc.
[_extract(self._inst, name, required) for name in methods],
[_extract(self._inst, name) for name in methods],
)
self._inst.handle = self.handle

def __getattr__(self, name: str):
"""Bridge the attribute function."""
Expand All @@ -123,9 +122,6 @@ def __setattr__(self, name, value):
else:
super(TVMDerivedObject, self).__setattr__(name, value)

def __str__(self) -> str:
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})"

functools.update_wrapper(TVMDerivedObject.__init__, cls.__init__)
TVMDerivedObject.__name__ = cls.__name__
TVMDerivedObject.__doc__ = cls.__doc__
Expand Down Expand Up @@ -382,7 +378,7 @@ def inner(func: Callable):
return inner


def _get_hex_address(handle: ctypes.c_void_p) -> str:
def get_hex_address(handle: ctypes.c_void_p) -> str:
"""Get the hexadecimal address of a handle.
Parameters
----------
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_meta_schedule_measure_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tvm.meta_schedule.runner import RunnerResult
from tvm.meta_schedule.search_strategy import MeasureCandidate
from tvm.meta_schedule.task_scheduler.task_scheduler import TaskScheduler
from tvm.meta_schedule.utils import _get_hex_address
from tvm.meta_schedule.utils import get_hex_address
from tvm.script import tir as T
from tvm.tir.schedule import Schedule

Expand Down Expand Up @@ -119,7 +119,7 @@ def apply(
pass

def __str__(self) -> str:
return f"NotSoFancyMeasureCallback({_get_hex_address(self.handle)})"
return f"NotSoFancyMeasureCallback({get_hex_address(self.handle)})"

measure_callback = NotSoFancyMeasureCallback()
pattern = re.compile(r"NotSoFancyMeasureCallback\(0x[a-f|0-9]*\)")
Expand Down
11 changes: 8 additions & 3 deletions tests/python/unittest/test_meta_schedule_space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@

import tvm
from tvm.meta_schedule.utils import derived_object
from tvm.meta_schedule.space_generator import ScheduleFn, PySpaceGenerator, SpaceGeneratorUnion
from tvm.meta_schedule.tune_context import TuneContext
from tvm._ffi.base import TVMError
from tvm.script import tir as T
from tvm.tir.schedule import Schedule
from tvm.meta_schedule.space_generator import ScheduleFn, PySpaceGenerator, SpaceGeneratorUnion


# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
Expand Down Expand Up @@ -91,8 +93,11 @@ def test_meta_schedule_design_space_generator_NIE():
class TestPySpaceGenerator(PySpaceGenerator):
pass

with pytest.raises(NotImplementedError):
TestPySpaceGenerator()
with pytest.raises(
TVMError, match="PySpaceGenerator's InitializeWithTuneContext method not implemented!"
):
generator = TestPySpaceGenerator()
generator.initialize_with_tune_context(TuneContext())


if __name__ == "__main__":
Expand Down
6 changes: 4 additions & 2 deletions tests/python/unittest/test_meta_schedule_task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pytest
import tvm
from tvm._ffi.base import TVMError
from tvm.ir import IRModule
from tvm.meta_schedule import TuneContext, measure_callback
from tvm.meta_schedule.builder import BuilderInput, BuilderResult, PyBuilder
Expand Down Expand Up @@ -262,8 +263,9 @@ def test_meta_schedule_task_scheduler_NIE(): # pylint: disable=invalid-name
class MyTaskScheduler(PyTaskScheduler):
pass

with pytest.raises(NotImplementedError):
MyTaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase())
with pytest.raises(TVMError, match="PyTaskScheduler's NextTaskId method not implemented!"):
scheduler = MyTaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase())
scheduler.next_task_id()


def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: disable=invalid-name
Expand Down