Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule][M3a] Traced Schedule #8623

Merged
merged 1 commit into from
Aug 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -95,13 +96,15 @@ class ScheduleNode : public runtime::Object {
virtual ~ScheduleNode() = default;

static constexpr const char* _type_key = "tir.Schedule";
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleNode, runtime::Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, runtime::Object);

public:
/*! \brief Get the IRModule associated with this schedule. */
virtual IRModule mod() const { return state()->mod; }
/*! \return The internal state of scheduling */
virtual ScheduleState state() const = 0;
/*! \return The internally maintained trace of scheduling program execution */
virtual Optional<Trace> trace() const = 0;
/*!
* \brief Returns a copy of the schedule, including both its state and its symbol table,
* guaranteeing that
Expand Down Expand Up @@ -288,7 +291,7 @@ class Schedule : public runtime::ObjectRef {
/*!
* \brief Construct a concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param debug_mode Do extra correctness checking after the class creation
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
* \return The concrete schedule created
Expand All @@ -297,8 +300,22 @@ class Schedule : public runtime::ObjectRef {
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode,
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mask,
ScheduleErrorRenderLevel error_render_level);
/*!
* \brief Construct a traced concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
* \return The concrete schedule created
* \sa ScheduleDebugMask
* \note The checks performed include:
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Traced(IRModule mod, int debug_mask,
ScheduleErrorRenderLevel error_render_level);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

Expand Down
12 changes: 6 additions & 6 deletions include/tvm/tir/schedule/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ enum ScheduleDebugMask : uint32_t {
* 2) The sref tree of schedulable statements (indicated by the srefs)
* 3) The dependency information of each block scope (block_info)
* 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref)
* 5) A debug flag, if set, extra checking is enabled (debug_mode)
* 5) A debug flag, if set, extra checking is enabled (debug_mask)
*/
class ScheduleStateNode : public Object {
public:
Expand All @@ -99,13 +99,13 @@ class ScheduleStateNode : public Object {
* and each time after calling the Replace method.
* \sa ScheduleDebugMask
*/
int debug_mode;
int debug_mask;

void VisitAttrs(AttrVisitor* v) {
v->Visit("mod", &mod);
// `block_info` is not visited
// `stmt2ref` is not visited
v->Visit("debug_mode", &debug_mode);
v->Visit("debug_mask", &debug_mask);
}
/*!
* \brief Replace the part of the AST, as being pointed to by `src_sref`,
Expand All @@ -129,7 +129,7 @@ class ScheduleStateNode : public Object {
TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt,
const Map<Block, Block>& block_sref_reuse);
/*!
* \brief Trigger the verification according to the `debug_mode` bitmask.
* \brief Trigger the verification according to the `debug_mask` bitmask.
* 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.
* 2) If the bitmask `kVerifyCachedFlags` is on, verify the correctness of `affine_binding`,
* `region_cover` and `stage_pipeline`
Expand Down Expand Up @@ -186,10 +186,10 @@ class ScheduleState : public ObjectRef {
/*!
* \brief Construct a schedule state from an IRModule
* \param mod The IRModule to be scheduled
* \param debug_mode Do extra correctness checking after the class creation
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
*/
TVM_DLL explicit ScheduleState(IRModule mod, int debug_mode = 0);
TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0);

/*! \return The mutable pointer to the ScheduleStateNode */
ScheduleStateNode* get() const { return static_cast<ScheduleStateNode*>(data_.get()); }
Expand Down
96 changes: 54 additions & 42 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""The TensorIR schedule class"""
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
Expand All @@ -25,7 +24,8 @@
from tvm.tir import Block, For, IntImm, PrimFunc

from . import _ffi_api
from .state import ScheduleState, StmtSRef
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
from .trace import Trace


@register_error
Expand Down Expand Up @@ -63,7 +63,20 @@ def __init__(self) -> None:
RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name

# Update to `Literal["detail", "fast", "none"]` once upgraded to python3.8
ERROR_RENDER_LEVEL_CANDIDATES = Union[str] # pylint: disable=invalid-name
_ERROR_RENDER_LEVEL: Dict[str, int] = {
"detail": 0,
"fast": 1,
"none": 2,
}


def _parse_error_render_level(error_render_level: str) -> int:
if error_render_level not in _ERROR_RENDER_LEVEL:
raise ValueError(
'error_render_level can be "detail", "fast", or "none", but got: '
+ f"{error_render_level}"
)
return _ERROR_RENDER_LEVEL.get(error_render_level)


@_register_object("tir.Schedule")
Expand All @@ -81,73 +94,77 @@ class Schedule(Object):
Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html
"""

ERROR_RENDER_LEVEL = {
"detail": 0,
"fast": 1,
"none": 2,
}

def __init__(
self,
mod: Union[PrimFunc, IRModule],
*,
debug_mode: Union[bool, int] = False,
error_render_level: ERROR_RENDER_LEVEL_CANDIDATES = "detail",
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
) -> None:
"""Construct a concrete TensorIR schedule from an IRModule or a PrimFunc
"""Construct a TensorIR schedule class from an IRModule

Parameters
----------
mod : Union[PrimFunc, IRModule]
The IRModule or PrimFunc to be scheduled
debug_mode : Union[bool, int]
debug_mask : Union[str, int]
Do extra correctness checking after the class creation and each time
scheduling primitive
after calling the Replace method.
Possible choices of `debug_mask`:
1) "all" - Turn on all the checks
2) "none" - Turn off all the checks
3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
error_render_level : str = "detail"
The level of error rendering. Choices: "detail", "fast", "none".
"detail": Render a detailed error message, with the TIR and error locations printed
"fast: Show a simple error message without rendering or string manipulation
"none": Do not show any error message.
- "detail": Render a detailed error message, with the TIR and error locations printed
- "fast: Show a simple error message without rendering or string manipulation
- "none": Do not show any error message.

Note
----
The checks performed includes:
1) VerifySRefTree
2) VerifyCachedFlags
"""
if isinstance(mod, PrimFunc):
mod = IRModule({"main": mod})
if isinstance(debug_mode, bool):
if debug_mode:
debug_mode = -1
else:
debug_mode = 0
if not isinstance(debug_mode, int):
raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}")
if error_render_level not in Schedule.ERROR_RENDER_LEVEL:
raise ValueError(
'error_render_level can be "detail", "fast", or "none", but got: '
+ f"{error_render_level}"
)
# call the constructor
self.__init_handle_by_constructor__(
_ffi_api.ConcreteSchedule, # type: ignore # pylint: disable=no-member
mod,
debug_mode,
Schedule.ERROR_RENDER_LEVEL.get(error_render_level),
_ffi_api.TracedSchedule, # type: ignore # pylint: disable=no-member
_parse_mod(mod),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
)

@staticmethod
def _create_non_traced(
mod: Union[PrimFunc, IRModule],
*,
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
) -> "Schedule":
"""Construct a non-traced TensorIR schedule class from an IRModule."""
return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member
_parse_mod(mod),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
)

########## Utilities ##########

@property
def mod(self) -> IRModule:
"""Returns the AST of the module being scheduled"""
return _ffi_api.ScheduleModule(self) # type: ignore # pylint: disable=no-member
return _ffi_api.ScheduleGetMod(self) # type: ignore # pylint: disable=no-member
junrushao marked this conversation as resolved.
Show resolved Hide resolved

@property
def state(self) -> ScheduleState:
"""Returns the ScheduleState in the current schedule class"""
return _ffi_api.ScheduleGetState(self) # type: ignore # pylint: disable=no-member

@property
def trace(self) -> Optional[Trace]:
"""Returns the internally maintained trace of scheduling program execution"""
comaniac marked this conversation as resolved.
Show resolved Hide resolved
return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member

def copy(self) -> "Schedule":
"""Returns a copy of the schedule, including both the state and the symbol table,
* guaranteeing that
Expand Down Expand Up @@ -702,8 +719,3 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None:
def enter_postproc(self) -> None:
"""A no-op that marks the start of postprocessing phase of scheduling"""
_ffi_api.ScheduleEnterPostproc(self) # type: ignore # pylint: disable=no-member


@_register_object("tir.ConcreteSchedule")
class ConcreteSchedule(Schedule):
"""A concrete schedule class of TensorIR. Do not use directly, use tvm.tir.Schedule instead."""
57 changes: 35 additions & 22 deletions python/tvm/tir/schedule/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@


class ScheduleDebugMask(IntEnum):
"""The bitmask of the `debug_mode` flag in the ScheduleState class.
"""The bitmask of the `debug_mask` flag in the ScheduleState class.

If the `debug_mode` flag has a certain bit on, then the correpsonding
verification pass will be conducted. For example, if `(debug_mode & VERIFY_SREF_TREE) != 0`,
If the `debug_mask` flag has a certain bit on, then the correpsonding
verification pass will be conducted. For example, if `(debug_mask & VERIFY_SREF_TREE) != 0`,
then the correctness of the sref tree will be verified after each schedule instruction.

Attributes
Expand All @@ -49,6 +49,27 @@ class ScheduleDebugMask(IntEnum):
VERIFY_CACHED_FLAGS = 2


def _parse_mod(mod: Union[PrimFunc, IRModule]) -> IRModule:
if isinstance(mod, PrimFunc):
mod = IRModule({"main": mod})
if not isinstance(mod, IRModule):
raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}")
return mod


def _parse_debug_mask(debug_mask: Union[str, int]) -> int:
if isinstance(debug_mask, str):
if debug_mask == "all":
debug_mask = ScheduleDebugMask.VERIFY_SREF_TREE | ScheduleDebugMask.VERIFY_CACHED_FLAGS
elif debug_mask == "none":
debug_mask = 0
else:
raise ValueError(f"Unrecognizable `debug_mask`: {debug_mask}")
if isinstance(debug_mask, bool) or not isinstance(debug_mask, int):
raise TypeError(f"`debug_mask` should be integer or boolean, but gets: {debug_mask}")
return debug_mask


@register_object("tir.ScheduleState")
class ScheduleState(Object):
"""The state of scheduling, which exposes a `Replace` method as
Expand All @@ -59,52 +80,44 @@ class ScheduleState(Object):
2) The sref tree of schedulable statements (indicated by the srefs)
3) The dependency information of each block scope (block_info)
4) A reverse mapping from the AST nodes to that in the sref tree (get_sref)
5) A debug flag, if set, extra checking is enabled (debug_mode)
5) A debug flag, if set, extra checking is enabled (debug_mask)

Parameters
----------
mod : IRModule
The AST of the module being scheduled
debug_mode : int
debug_mask : int
Do extra correctness checking after the object construction
and each time after calling the Replace method.
"""

mod: IRModule
debug_mode: int
debug_mask: int

def __init__(
self,
mod: Union[PrimFunc, IRModule],
debug_mode: Union[bool, int] = False,
*,
debug_mask: Union[str, int] = "none",
) -> None:
"""Construct a schedule state from an IRModule or a PrimFunc

Parameters
----------
mod : Union[PrimFunc, IRModule]
The IRModule or PrimFunc to be scheduled
debug_mode : Union[bool, int]
debug_mask : Union[str, int]
Do extra correctness checking after the class creation and each time
after calling the Replace method.
Possible choices of `debug_mode`:
1) True - Turn on all the checks
2) False - Turn off all the checks
Possible choices of `debug_mask`:
1) "all" - Turn on all the checks
2) "none" - Turn off all the checks
3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
"""
if isinstance(mod, PrimFunc):
mod = IRModule({"main": mod})
if isinstance(debug_mode, bool):
if debug_mode:
debug_mode = -1
else:
debug_mode = 0
if not isinstance(debug_mode, int):
raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}")
self.__init_handle_by_constructor__(
_ffi_api.ScheduleState, # type: ignore # pylint: disable=no-member
mod,
debug_mode,
_parse_mod(mod),
_parse_debug_mask(debug_mask),
)

def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]:
Expand Down
Loading