Skip to content

Commit

Permalink
Fix linting.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Feb 24, 2022
1 parent 67d3637 commit 70a8c34
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 15 deletions.
6 changes: 4 additions & 2 deletions python/tvm/meta_schedule/builder/local_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
logger = logging.getLogger(__name__) # pylint: disable=invalid-name


T_BUILD = Callable[[IRModule, Target, Optional[Dict[str, NDArray]]], Module]
T_EXPORT = Callable[[Module], str]
T_BUILD = Callable[ # pylint: disable=invalid-name
[IRModule, Target, Optional[Dict[str, NDArray]]], Module
]
T_EXPORT = Callable[[Module], str] # pylint: disable=invalid-name


def _serialize_params(params: Optional[Dict[str, NDArray]]) -> Optional[bytearray]:
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/runner/local_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@
logger = logging.getLogger(__name__) # pylint: disable=invalid-name


T_ALLOC_ARGUMENT = Callable[
T_ALLOC_ARGUMENT = Callable[ # pylint: disable=invalid-name
[
Device, # The device on the remote
T_ARG_INFO_JSON_OBJ_LIST, # The metadata information of the arguments to be allocated
int, # The number of repeated allocations to be done
],
List[T_ARGUMENT_LIST], # A list of argument lists
]
T_RUN_EVALUATOR = Callable[
T_RUN_EVALUATOR = Callable[ # pylint: disable=invalid-name
[
Module, # The Module opened on the remote
Device, # The device on the remote
Expand All @@ -53,7 +53,7 @@
],
List[float], # A list of running time
]
T_CLEANUP = Callable[
T_CLEANUP = Callable[ # pylint: disable=invalid-name
[],
None,
]
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/meta_schedule/runner/rpc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,19 @@
logger = logging.getLogger(__name__) # pylint: disable=invalid-name


T_CREATE_SESSION = Callable[
T_CREATE_SESSION = Callable[ # pylint: disable=invalid-name
[RPCConfig], # The RPC configuration
RPCSession, # The RPC Session
]
T_UPLOAD_MODULE = Callable[
T_UPLOAD_MODULE = Callable[ # pylint: disable=invalid-name
[
RPCSession, # The RPC Session
str, # local path to the artifact
str, # remote path to the artifact
],
Module, # the Module opened on the remote
]
T_ALLOC_ARGUMENT = Callable[
T_ALLOC_ARGUMENT = Callable[ # pylint: disable=invalid-name
[
RPCSession, # The RPC Session
Device, # The device on the remote
Expand All @@ -64,7 +64,7 @@
],
List[T_ARGUMENT_LIST], # A list of argument lists
]
T_RUN_EVALUATOR = Callable[
T_RUN_EVALUATOR = Callable[ # pylint: disable=invalid-name
[
RPCSession, # The RPC Session
Module, # The Module opened on the remote
Expand All @@ -74,7 +74,7 @@
],
List[float], # A list of running time
]
T_CLEANUP = Callable[
T_CLEANUP = Callable[ # pylint: disable=invalid-name
[
Optional[RPCSession], # The RPC Session to be cleaned up
Optional[str], # remote path to the artifact
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from .. import _ffi_api
from ..arg_info import ArgInfo
from ..runner import RunnerResult
from ..utils import check_override

if TYPE_CHECKING:
from ..tune_context import TuneContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from tvm.tir.schedule import Schedule

from .. import _ffi_api
from ..utils import check_override

if TYPE_CHECKING:
from ..tune_context import TuneContext
Expand Down
12 changes: 9 additions & 3 deletions python/tvm/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from tvm.tir import FloatImm, IntImm


def derived_object(cls) -> type:
def derived_object(cls: Any) -> type:
"""A decorator to register derived subclasses for TVM objects.
Parameters
----------
Expand All @@ -56,6 +56,7 @@ class PyRunner():
_tvm_metadata = {
"cls": _PyRunner,
"methods": ["run"],
"required": {"run"}
}
def run(self, runner_inputs):
raise NotImplementedError
Expand All @@ -69,6 +70,8 @@ def run(self, runner_inputs):
import functools # pylint: disable=import-outside-toplevel

def _extract(inst: Any, name: str, required: Set[str]):
"""Extract function from intrinsic class."""

def method(*args, **kwargs):
return getattr(inst, name)(*args, **kwargs)

Expand All @@ -92,8 +95,10 @@ def method(*args, **kwargs):
required = metadata.get("required", {})

class TVMDerivedObject(metadata["cls"]): # type: ignore
def __init__(self, *args, **kwargs):
"""The derived object to avoid cyclic dependency."""

def __init__(self, *args, **kwargs):
"""Constructor."""
self.handle = None
self._inst = cls(*args, **kwargs)
# make sure the inner class can access the outside
Expand All @@ -106,7 +111,8 @@ def __init__(self, *args, **kwargs):
[_extract(self._inst, name, required) for name in methods],
)

def __getattr__(self, name):
def __getattr__(self, name: str):
"""Bridge the attribute function."""
return self._inst.__getattribute__(name)

functools.update_wrapper(TVMDerivedObject.__init__, cls.__init__)
Expand Down

0 comments on commit 70a8c34

Please sign in to comment.