Skip to content

Commit

Permalink
Fix task scheduler, space generator & search strategy.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Feb 24, 2022
1 parent a5df3e5 commit 67d3637
Show file tree
Hide file tree
Showing 10 changed files with 306 additions and 129 deletions.
7 changes: 2 additions & 5 deletions python/tvm/meta_schedule/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,13 @@ def __init__(self, methods: List[Callable]):

class PyBuilder:
"""
An abstract builder with customized run method on the python-side.
An abstract builder with customized build method on the python-side.
This is the user facing class for function overloading inheritance.
Note: @derived_object is required for proper usage of any inherited class.
"""

tvm_metadata = {
"cls": _PyBuilder,
"methods": ["build"],
}
_tvm_metadata = {"cls": _PyBuilder, "methods": ["build"], "required": {"build"}}

def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
"""Build the given inputs.
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/meta_schedule/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ def LocalRunnerFuture(PyRunnerFuture):
...
"""

tvm_metadata = {
_tvm_metadata = {
"cls": RunnerFuture,
"methods": ["done", "result"],
"required": {"done", "result"},
}

def done(self) -> bool:
Expand Down Expand Up @@ -204,9 +205,10 @@ class PyRunner:
Note: @derived_object is required for proper usage of any inherited class.
"""

tvm_metadata = {
_tvm_metadata = {
"cls": _PyRunner,
"methods": ["run"],
"required": {"run"}
}

def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
Expand Down
119 changes: 88 additions & 31 deletions python/tvm/meta_schedule/search_strategy/search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Meta Schedule search strategy that generates the measure
candidates for measurement.
"""
from typing import List, Optional, TYPE_CHECKING
from typing import Callable, List, Optional, TYPE_CHECKING

from tvm._ffi import register_object
from tvm.runtime import Object
Expand Down Expand Up @@ -138,41 +138,98 @@ def notify_runner_results(


@register_object("meta_schedule.PySearchStrategy")
class PySearchStrategy(SearchStrategy):
"""An abstract search strategy with customized methods on the python-side."""
class _PySearchStrategy(SearchStrategy):
"""
A TVM object search strategy to support customization on the python side.
This is NOT the user facing class for function overloading inheritance.
See also: PySearchStrategy
"""

def __init__(self):
def __init__(self, methods: List[Callable]):
"""Constructor."""

@check_override(self.__class__, SearchStrategy)
def f_initialize_with_tune_context(context: "TuneContext") -> None:
self.initialize_with_tune_context(context)
self.__init_handle_by_constructor__(
_ffi_api.SearchStrategyPySearchStrategy, # type: ignore # pylint: disable=no-member
*methods,
)

@check_override(self.__class__, SearchStrategy)
def f_pre_tuning(design_spaces: List[Schedule]) -> None:
self.pre_tuning(design_spaces)

@check_override(self.__class__, SearchStrategy)
def f_post_tuning() -> None:
self.post_tuning()
class PySearchStrategy:
"""
An abstract search strategy with customized methods on the python-side.
This is the user facing class for function overloading inheritance.
@check_override(self.__class__, SearchStrategy)
def f_generate_measure_candidates() -> List[MeasureCandidate]:
return self.generate_measure_candidates()
Note: @derived_object is required for proper usage of any inherited class.
"""

@check_override(self.__class__, SearchStrategy)
def f_notify_runner_results(
context: "TuneContext",
measure_candidates: List[MeasureCandidate],
results: List["RunnerResult"],
) -> None:
self.notify_runner_results(context, measure_candidates, results)
_tvm_metadata = {
"cls": _PySearchStrategy,
"methods": [
"initialize_with_tune_context",
"pre_tuning",
"post_tuning",
"generate_measure_candidates",
"notify_runner_results",
],
"required": {
"initialize_with_tune_context",
"pre_tuning",
"post_tuning",
"generate_measure_candidates",
"notify_runner_results",
},
}

self.__init_handle_by_constructor__(
_ffi_api.SearchStrategyPySearchStrategy, # type: ignore # pylint: disable=no-member
f_initialize_with_tune_context,
f_pre_tuning,
f_post_tuning,
f_generate_measure_candidates,
f_notify_runner_results,
)
def initialize_with_tune_context(self, context: "TuneContext") -> None:
"""Initialize the search strategy with tuning context.
Parameters
----------
context : TuneContext
The tuning context for initialization.
"""
raise NotImplementedError

def pre_tuning(self, design_spaces: List[Schedule]) -> None:
"""Pre-tuning for the search strategy.
Parameters
----------
design_spaces : List[Schedule]
The design spaces for pre-tuning.
"""
raise NotImplementedError

def post_tuning(self) -> None:
"""Post-tuning for the search strategy."""
raise NotImplementedError

def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]:
"""Generate measure candidates from design spaces for measurement.
Returns
-------
measure_candidates : Optional[List[IRModule]]
The measure candidates generated, None if finished.
"""
raise NotImplementedError

def notify_runner_results(
self,
context: "TuneContext",
measure_candidates: List[MeasureCandidate],
results: List[RunnerResult],
) -> None:
"""Update the search strategy with profiling results.
Parameters
----------
context : TuneContext
The tuning context for update.
measure_candidates : List[MeasureCandidate]
The measure candidates for update.
results : List[RunnerResult]
The profiling results from the runner.
"""
raise NotImplementedError
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/space_generator/schedule_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from tvm.ir import IRModule
from tvm.ir.container import Array
from tvm.meta_schedule.utils import derived_object
from tvm.tir.schedule import Schedule

from .space_generator import PySpaceGenerator
Expand All @@ -30,6 +31,7 @@
from ..tune_context import TuneContext


@derived_object
class ScheduleFn(PySpaceGenerator):
"""A design space generator with design spaces specified by a schedule function."""

Expand Down
64 changes: 50 additions & 14 deletions python/tvm/meta_schedule/space_generator/space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Meta Schedule design space generators that generates design
space for generation of measure candidates.
"""
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, Callable, List

from tvm._ffi import register_object
from tvm.ir import IRModule
Expand Down Expand Up @@ -65,22 +65,58 @@ def generate_design_space(self, mod: IRModule) -> List[Schedule]:


@register_object("meta_schedule.PySpaceGenerator")
class PySpaceGenerator(SpaceGenerator):
"""An abstract design space generator with customized methods on the python-side."""
class _PySpaceGenerator(SpaceGenerator):
"""
A TVM object space generator to support customization on the python side.
This is NOT the user facing class for function overloading inheritance.
def __init__(self):
"""Constructor."""

@check_override(self.__class__, SpaceGenerator)
def f_initialize_with_tune_context(context: "TuneContext") -> None:
self.initialize_with_tune_context(context)
See also: PySpaceGenerator
"""

@check_override(self.__class__, SpaceGenerator)
def f_generate_design_space(mod: IRModule) -> List[Schedule]:
return self.generate_design_space(mod)
def __init__(self, methods: List[Callable]):
"""Constructor."""

self.__init_handle_by_constructor__(
_ffi_api.SpaceGeneratorPySpaceGenerator, # type: ignore # pylint: disable=no-member
f_initialize_with_tune_context,
f_generate_design_space,
*methods,
)


class PySpaceGenerator:
"""
An abstract space generator with customized methods on the python-side.
This is the user facing class for function overloading inheritance.
Note: @derived_object is required for proper usage of any inherited class.
"""

_tvm_metadata = {
"cls": _PySpaceGenerator,
"methods": ["initialize_with_tune_context", "generate_design_space"],
"required": {"initialize_with_tune_context", "generate_design_space"},
}

def initialize_with_tune_context(self, context: "TuneContext") -> None:
"""Initialize the design space generator with tuning context.
Parameters
----------
context : TuneContext
The tuning context for initializing the design space generator.
"""
raise NotImplementedError

def generate_design_space(self, mod: IRModule) -> List[Schedule]:
"""Generate design spaces given a module.
Parameters
----------
mod : IRModule
The module used for design space generation.
Returns
-------
design_spaces : List[Schedule]
The generated design spaces, i.e., schedules.
"""
raise NotImplementedError
Loading

0 comments on commit 67d3637

Please sign in to comment.