Skip to content

Commit

Permalink
feat: add ruff integration (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored Feb 19, 2023
1 parent a112b5b commit 0a71edc
Show file tree
Hide file tree
Showing 15 changed files with 160 additions and 47 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ jobs:
run: |
make pre-commit
- name: ruff
run: |
make ruff
- name: flake8
run: |
make flake8
Expand Down
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ repos:
hooks:
- id: clang-format
stages: [commit, push, manual]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.247
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
stages: [commit, push, manual]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#33](https://github.com/metaopt/optree/pull/33) and [#34](https://github.com/metaopt/optree/pull/34).

### Changed

Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ make uninstall

We use several tools to secure code quality, including:

- PEP8 code style: `black`, `isort`, `pylint`, `flake8`
- Python code style: `black`, `isort`, `pylint`, `flake8`, `ruff`
- Type hint check: `mypy`
- C++ code style: `cpplint`, `clang-format`, `clang-tidy`
- License: `addlicense`
Expand Down
14 changes: 12 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ py-format-install:
$(call check_pip_install,isort)
$(call check_pip_install,black)

ruff-install:
$(call check_pip_install,ruff)

mypy-install:
$(call check_pip_install,mypy)

Expand Down Expand Up @@ -132,6 +135,12 @@ py-format: py-format-install
$(PYTHON) -m isort --project $(PROJECT_NAME) --check $(PYTHON_FILES) && \
$(PYTHON) -m black --check $(PYTHON_FILES)

ruff: ruff-install
$(PYTHON) -m ruff check .

ruff-fix: ruff-install
$(PYTHON) -m ruff check . --fix --exit-non-zero-on-fix

mypy: mypy-install
$(PYTHON) -m mypy $(PROJECT_PATH)

Expand Down Expand Up @@ -187,11 +196,12 @@ clean-docs:

# Utility functions

lint: flake8 py-format mypy pylint doctest clang-format clang-tidy cpplint addlicense docstyle spelling
lint: ruff flake8 py-format mypy pylint doctest clang-format clang-tidy cpplint addlicense docstyle spelling

format: py-format-install clang-format-install addlicense-install
format: py-format-install ruff-install clang-format-install addlicense-install
$(PYTHON) -m isort --project $(PROJECT_NAME) $(PYTHON_FILES)
$(PYTHON) -m black $(PYTHON_FILES)
$(PYTHON) -m ruff check . --fix --exit-zero
$(CLANG_FORMAT) -style=file -i $(CXX_FILES)
addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") $(SOURCE_FOLDERS)

Expand Down
2 changes: 1 addition & 1 deletion benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def fn3(x, y, z):
)


def cprint(text=''):
def cprint(text: str = '') -> None:
text = (
text.replace(
', none_is_leaf=False',
Expand Down
1 change: 1 addition & 0 deletions conda-recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ dependencies:
- flake8-docstrings
- flake8-pyi
- flake8-simplify
- ruff
- doc8
- pydocstyle
- xdoctest
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

def get_version() -> str:
sys.path.insert(0, str(PROJECT_ROOT / 'optree'))
import version # noqa
import version

return version.__version__

Expand Down
34 changes: 22 additions & 12 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,31 @@ MAX_RECURSION_DEPTH: int

def flatten(
tree: PyTree[T],
leaf_predicate: Callable[[T], bool] | None = None,
node_is_leaf: bool = False,
namespace: str = '',
leaf_predicate: Callable[[T], bool] | None = ..., # None
node_is_leaf: bool = ..., # False
namespace: str = ..., # ''
) -> builtins.tuple[list[T], PyTreeSpec]: ...
def flatten_with_path(
tree: PyTree[T],
leaf_predicate: Callable[[T], bool] | None = None,
node_is_leaf: bool = False,
namespace: str = '',
leaf_predicate: Callable[[T], bool] | None = ..., # None
node_is_leaf: bool = ..., # False
namespace: str = ..., # ''
) -> builtins.tuple[list[builtins.tuple[Any, ...]], list[T], PyTreeSpec]: ...
def all_leaves(
iterable: Iterable[T],
node_is_leaf: bool = False,
namespace: str = '',
node_is_leaf: bool = ..., # False
namespace: str = ..., # ''
) -> bool: ...
def leaf(node_is_leaf: bool = False) -> PyTreeSpec: ...
def none(node_is_leaf: bool = False) -> PyTreeSpec: ...
def tuple(treespecs: Sequence[PyTreeSpec], node_is_leaf: bool = False) -> PyTreeSpec: ...
def leaf(
node_is_leaf: bool = ..., # False
) -> PyTreeSpec: ...
def none(
node_is_leaf: bool = ..., # False
) -> PyTreeSpec: ...
def tuple(
treespecs: Sequence[PyTreeSpec],
node_is_leaf: bool = ..., # False
) -> PyTreeSpec: ...
def is_namedtuple_class(cls: type) -> bool: ...
def is_structseq_class(cls: type) -> bool: ...
def structseq_fields(obj: object | type) -> builtins.tuple[str]: ...
Expand All @@ -66,7 +73,10 @@ class PyTreeSpec:
leaves: Iterable[T],
) -> U: ...
def children(self) -> list[PyTreeSpec]: ...
def is_leaf(self, strict: bool = True) -> bool: ...
def is_leaf(
self,
strict: bool = ..., # True
) -> bool: ...
def __eq__(self, other: object) -> bool: ...
def __ne__(self, other: object) -> bool: ...
def __hash__(self) -> int: ...
Expand Down
9 changes: 3 additions & 6 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from collections import deque
from typing import Any, Callable, cast, overload

import optree._C as _C
from optree import _C
from optree.registry import (
AttributeKeyPathEntry,
FlattenedKeyPathEntry,
Expand Down Expand Up @@ -1157,10 +1157,7 @@ def flatten_one_level(
)
children, metadata, entries = flattened
children = list(children)
if entries is None:
entries = tuple(range(len(children)))
else:
entries = tuple(entries)
entries = tuple(range(len(children)) if entries is None else entries)
if len(children) != len(entries):
raise RuntimeError(
f'PyTree custom flatten function for type {node_type} returned inconsistent '
Expand Down Expand Up @@ -1300,7 +1297,7 @@ def _child_keys(
namespace: str = '',
) -> list[KeyPathEntry]:
treespec = tree_structure(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
assert not treespec_is_strict_leaf(treespec)
assert not treespec_is_strict_leaf(treespec), 'treespec must be a non-leaf node'

handler = register_keypaths.get(type(tree)) # type: ignore[attr-defined]
if handler:
Expand Down
12 changes: 6 additions & 6 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from threading import Lock
from typing import Any, Callable, Iterable, NamedTuple, Sequence, overload

import optree._C as _C
from optree import _C
from optree.typing import KT, VT, CustomTreeNode, FlattenFunc, PyTree, T, UnflattenFunc
from optree.utils import safe_zip, unzip2

Expand Down Expand Up @@ -387,7 +387,7 @@ class _HashablePartialShim:
def __init__(self, partial_func: functools.partial) -> None:
self.partial_func: functools.partial = partial_func

def __call__(self, *args, **kwargs) -> Any:
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.partial_func(*args, **kwargs)

def __hash__(self) -> int:
Expand Down Expand Up @@ -449,7 +449,7 @@ class Partial(functools.partial, CustomTreeNode[Any]): # pylint: disable=too-fe
args: tuple[Any, ...]
keywords: dict[str, Any]

def __new__(cls, func: Callable[..., Any], *args, **keywords) -> Partial:
def __new__(cls, func: Callable[..., Any], *args: Any, **keywords: Any) -> Partial:
"""Create a new :class:`Partial` instance."""
# In Python 3.10+, if func is itself a functools.partial instance, functools.partial.__new__
# would merge the arguments of this Partial instance with the arguments of the func. We box
Expand All @@ -458,7 +458,7 @@ def __new__(cls, func: Callable[..., Any], *args, **keywords) -> Partial:
if isinstance(func, functools.partial):
original_func = func
func = _HashablePartialShim(original_func)
assert not hasattr(func, 'func')
assert not hasattr(func, 'func'), 'shimmed function should not have a `func` attribute'
out = super().__new__(cls, func, *args, **keywords)
func.func = original_func.func
func.args = original_func.args
Expand Down Expand Up @@ -488,7 +488,7 @@ def __add__(self, other: object) -> KeyPath:
if isinstance(other, KeyPathEntry):
return KeyPath((self, other))
if isinstance(other, KeyPath):
return KeyPath((self,) + other.keys)
return KeyPath((self, *other.keys))
return NotImplemented

def __eq__(self, other: object) -> bool:
Expand All @@ -504,7 +504,7 @@ class KeyPath(NamedTuple):

def __add__(self, other: object) -> KeyPath:
if isinstance(other, KeyPathEntry):
return KeyPath(self.keys + (other,))
return KeyPath((*self.keys, other))
if isinstance(other, KeyPath):
return KeyPath(self.keys + other.keys)
return NotImplemented
Expand Down
35 changes: 21 additions & 14 deletions optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Hashable,
Iterable,
List,
NoReturn,
Optional,
Sequence,
Tuple,
Expand All @@ -39,7 +40,7 @@
from typing_extensions import Protocol # Python 3.8+
from typing_extensions import TypeAlias # Python 3.10+

import optree._C as _C
from optree import _C


try:
Expand Down Expand Up @@ -118,18 +119,18 @@ def tree_unflatten(cls, metadata: MetaData, children: Children[T]) -> CustomTree
_GenericAlias = type(Union[int, str])


def _tp_cache(func):
def _tp_cache(func: Callable) -> Callable:
import functools # pylint: disable=import-outside-toplevel

cached = functools.lru_cache()(func)

@functools.wraps(func)
def inner(*args, **kwds):
try: # noqa: SIM105
def inner(*args: Any, **kwds: Any) -> Any:
try:
return cached(*args, **kwds)
except TypeError:
pass # All real errors (not unhashable args) are raised below.
return func(*args, **kwds)
# All real errors (not unhashable args) are raised below.
return func(*args, **kwds)

return inner

Expand All @@ -150,7 +151,9 @@ class PyTree(Generic[T]): # pylint: disable=too-few-public-methods
"""

@_tp_cache
def __class_getitem__(cls, item: T | tuple[T] | tuple[T, str | None]) -> TypeAlias:
def __class_getitem__( # noqa: C901
cls, item: T | tuple[T] | tuple[T, str | None]
) -> TypeAlias:
"""Instantiate a PyTree type with the given type."""
if not isinstance(item, tuple):
item = (item, None)
Expand Down Expand Up @@ -200,15 +203,19 @@ def __class_getitem__(cls, item: T | tuple[T] | tuple[T, str | None]) -> TypeAli
pytree_alias.__pytree_args__ = item # type: ignore[attr-defined]
return pytree_alias

def __init_subclass__(cls, *args, **kwargs):
def __new__(cls) -> NoReturn: # pylint: disable=arguments-differ
"""Prohibit instantiation."""
raise TypeError('Cannot instantiate special typing classes.')

def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn:
"""Prohibit subclassing."""
raise TypeError('Cannot subclass special typing classes.')

def __copy__(self):
def __copy__(self) -> PyTree:
"""Immutable copy."""
return self

def __deepcopy__(self, memo):
def __deepcopy__(self, memo: dict[int, Any]) -> PyTree:
"""Immutable copy."""
return self

Expand All @@ -235,15 +242,15 @@ def __new__(cls, name: str, param: type) -> TypeAlias:
raise TypeError(f'{cls.__name__} only supports a string of type name. Got {name!r}.')
return PyTree[param, name] # type: ignore[misc,valid-type]

def __init_subclass__(cls, *args, **kwargs):
def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn:
"""Prohibit subclassing."""
raise TypeError('Cannot subclass special typing classes.')

def __copy__(self):
def __copy__(self) -> TypeAlias:
"""Immutable copy."""
return self

def __deepcopy__(self, memo):
def __deepcopy__(self, memo: dict[int, Any]) -> TypeAlias:
"""Immutable copy."""
return self

Expand Down Expand Up @@ -296,5 +303,5 @@ class SubClass(cls): # type: ignore[misc,valid-type]


# Ensure that the behavior is consistent with C++ implementation
# pylint: disable-next=wrong-import-position
# pylint: disable-next=wrong-import-position,ungrouped-imports
from optree._C import is_namedtuple_class, is_structseq_class, structseq_fields
Loading

0 comments on commit 0a71edc

Please sign in to comment.