Skip to content

Commit

Permalink
FEAT-modin-project#7146: Use BaseQueryCompiler, BasePandasDataset, Da…
Browse files Browse the repository at this point in the history
…taFrame or Series type hints at a high level (modin-project#7147)

Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
  • Loading branch information
anmyachev authored Apr 4, 2024
1 parent fa6e02a commit c1865b8
Show file tree
Hide file tree
Showing 15 changed files with 114 additions and 32 deletions.
2 changes: 1 addition & 1 deletion modin/core/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pandas._libs.lib import no_default
from pandas.util._decorators import doc

from modin.core.storage_formats.base.query_compiler import BaseQueryCompiler
from modin.core.storage_formats import BaseQueryCompiler
from modin.db_conn import ModinDatabaseConnection
from modin.error_message import ErrorMessage
from modin.pandas.io import ExcelFile
Expand Down
2 changes: 1 addition & 1 deletion modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
ModinIndex,
extract_dtype,
)
from modin.core.storage_formats.base.query_compiler import BaseQueryCompiler
from modin.core.storage_formats import BaseQueryCompiler
from modin.error_message import ErrorMessage
from modin.logging import get_logger
from modin.utils import (
Expand Down
4 changes: 2 additions & 2 deletions modin/distributed/dataframe/pandas/partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def unwrap_partitions(
f"Only API Layer objects may be passed in here, got {type(api_layer_object)} instead."
)

modin_frame = api_layer_object._query_compiler._modin_frame
modin_frame = api_layer_object._query_compiler._modin_frame # type: ignore[attr-defined]
modin_frame._propagate_index_objs(None)
if axis is None:

Expand Down Expand Up @@ -122,7 +122,7 @@ def get_block(partition: PartitionUnionType) -> np.ndarray:
]

actual_engine = type(
api_layer_object._query_compiler._modin_frame._partitions[0][0]
api_layer_object._query_compiler._modin_frame._partitions[0][0] # type: ignore[attr-defined]
).__name__
if actual_engine in (
"PandasOnRayDataframePartition",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pandas.core.common import is_bool_indexer
from pandas.core.dtypes.common import is_bool_dtype, is_integer_dtype

from modin.core.storage_formats.base.query_compiler import BaseQueryCompiler
from modin.core.storage_formats import BaseQueryCompiler
from modin.core.storage_formats.base.query_compiler import (
_get_axis as default_axis_getter,
)
Expand Down
21 changes: 15 additions & 6 deletions modin/pandas/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
CachedAccessor implements API of pandas.core.accessor.CachedAccessor
"""

from __future__ import annotations

import pickle
import warnings
from typing import TYPE_CHECKING, Union

import pandas
from pandas._typing import CompressionOptions, StorageOptions
Expand All @@ -34,6 +37,9 @@
from modin.pandas.io import to_dask, to_ray
from modin.utils import _inherit_docstrings

if TYPE_CHECKING:
from modin.pandas import DataFrame, Series


class BaseSparseAccessor(ClassLogger):
"""
Expand All @@ -45,20 +51,21 @@ class BaseSparseAccessor(ClassLogger):
Object to operate on.
"""

_parent: Union[DataFrame, Series]
_validation_msg = "Can only use the '.sparse' accessor with Sparse data."

def __init__(self, data=None):
def __init__(self, data: Union[DataFrame, Series] = None):
self._parent = data
self._validate(data)

@classmethod
def _validate(cls, data):
def _validate(cls, data: Union[DataFrame, Series]):
"""
Verify that `data` dtypes are compatible with `pandas.core.dtypes.dtypes.SparseDtype`.
Parameters
----------
data : DataFrame
data : DataFrame or Series
Object to check.
Raises
Expand Down Expand Up @@ -94,7 +101,7 @@ def _default_to_pandas(self, op, *args, **kwargs):
@_inherit_docstrings(pandas.core.arrays.sparse.accessor.SparseFrameAccessor)
class SparseFrameAccessor(BaseSparseAccessor):
@classmethod
def _validate(cls, data):
def _validate(cls, data: DataFrame):
"""
Verify that `data` dtypes are compatible with `pandas.core.dtypes.dtypes.SparseDtype`.
Expand Down Expand Up @@ -133,7 +140,7 @@ def to_coo(self):
@_inherit_docstrings(pandas.core.arrays.sparse.accessor.SparseAccessor)
class SparseAccessor(BaseSparseAccessor):
@classmethod
def _validate(cls, data):
def _validate(cls, data: Series):
"""
Verify that `data` dtype is compatible with `pandas.core.dtypes.dtypes.SparseDtype`.
Expand Down Expand Up @@ -208,7 +215,9 @@ class ModinAPI:
Object to operate on.
"""

def __init__(self, data):
_data: Union[DataFrame, Series]

def __init__(self, data: Union[DataFrame, Series]):
self._data = data

def to_pandas(self):
Expand Down
8 changes: 6 additions & 2 deletions modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pickle as pkl
import re
import warnings
from typing import Any, Hashable, Literal, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Hashable, Literal, Optional, Sequence, Union

import numpy as np
import pandas
Expand Down Expand Up @@ -67,6 +67,9 @@

from .utils import _doc_binary_op, is_full_grab_slice

if TYPE_CHECKING:
from modin.core.storage_formats import BaseQueryCompiler

# Similar to pandas, sentinel value to use as kwarg in place of None when None has
# special meaning and needs to be distinguished from a user explicitly passing None.
sentinel = object()
Expand Down Expand Up @@ -174,6 +177,7 @@ class BasePandasDataset(ClassLogger):
# Pandas class that we pretend to be; usually it has the same name as our class
# but lives in "pandas" namespace.
_pandas_class = pandas.core.generic.NDFrame
_query_compiler: BaseQueryCompiler

@pandas.util.cache_readonly
def _is_dataframe(self) -> bool:
Expand Down Expand Up @@ -577,7 +581,7 @@ def _get_axis_number(cls, axis):
return cls._pandas_class._get_axis_number(axis) if axis is not None else 0

@pandas.util.cache_readonly
def __constructor__(self):
def __constructor__(self) -> BasePandasDataset:
"""
Construct DataFrame or Series object depending on self type.
Expand Down
7 changes: 5 additions & 2 deletions modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import re
import sys
import warnings
from typing import IO, Hashable, Iterator, Optional, Sequence, Union
from typing import IO, TYPE_CHECKING, Hashable, Iterator, Optional, Sequence, Union

import numpy as np
import pandas
Expand Down Expand Up @@ -70,6 +70,9 @@
cast_function_modin2pandas,
)

if TYPE_CHECKING:
from modin.core.storage_formats import BaseQueryCompiler

# Dictionary of extensions assigned to this class
_DATAFRAME_EXTENSIONS_ = {}

Expand Down Expand Up @@ -129,7 +132,7 @@ def __init__(
columns=None,
dtype=None,
copy=None,
query_compiler=None,
query_compiler: BaseQueryCompiler = None,
):
from modin.numpy import array

Expand Down
2 changes: 1 addition & 1 deletion modin/pandas/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pandas._typing import DtypeBackend
from pandas.core.dtypes.common import is_list_like

from modin.core.storage_formats.base.query_compiler import BaseQueryCompiler
from modin.core.storage_formats import BaseQueryCompiler
from modin.error_message import ErrorMessage
from modin.logging import enable_logging
from modin.pandas.io import to_pandas
Expand Down
10 changes: 9 additions & 1 deletion modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@

"""Implement GroupBy public API as pandas does."""

from __future__ import annotations

import warnings
from collections.abc import Iterable
from types import BuiltinFunctionType
from typing import TYPE_CHECKING, Union

import numpy as np
import pandas
Expand Down Expand Up @@ -50,6 +53,9 @@
from .utils import is_label
from .window import RollingGroupby

if TYPE_CHECKING:
from modin.pandas import DataFrame

_DEFAULT_BEHAVIOUR = {
"__class__",
"__getitem__",
Expand Down Expand Up @@ -85,10 +91,12 @@
class DataFrameGroupBy(ClassLogger):
_pandas_class = pandas.core.groupby.DataFrameGroupBy
_return_tuple_when_iterating = False
_df: Union[DataFrame, Series]
_query_compiler: BaseQueryCompiler

def __init__(
self,
df,
df: Union[DataFrame, Series],
by,
axis,
level,
Expand Down
17 changes: 13 additions & 4 deletions modin/pandas/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
https://github.com/ray-project/ray/pull/1955#issuecomment-386781826
"""

from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, Union

import numpy as np
import pandas
Expand All @@ -44,6 +47,9 @@
from .series import Series
from .utils import is_scalar

if TYPE_CHECKING:
from modin.core.storage_formats import BaseQueryCompiler


def is_slice(x):
"""
Expand Down Expand Up @@ -273,11 +279,14 @@ class _LocationIndexerBase(ClassLogger):
Parameters
----------
modin_df : modin.pandas.DataFrame
modin_df : Union[DataFrame, Series]
DataFrame to operate on.
"""

def __init__(self, modin_df):
df: Union[DataFrame, Series]
qc: BaseQueryCompiler

def __init__(self, modin_df: Union[DataFrame, Series]):
self.df = modin_df
self.qc = modin_df._query_compiler

Expand Down Expand Up @@ -612,7 +621,7 @@ class _LocIndexer(_LocationIndexerBase):
Parameters
----------
modin_df : modin.pandas.DataFrame
modin_df : Union[DataFrame, Series]
DataFrame to operate on.
"""

Expand Down Expand Up @@ -967,7 +976,7 @@ class _iLocIndexer(_LocationIndexerBase):
Parameters
----------
modin_df : modin.pandas.DataFrame
modin_df : Union[DataFrame, Series]
DataFrame to operate on.
"""

Expand Down
10 changes: 9 additions & 1 deletion modin/pandas/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@

"""Place to define the Modin iterator."""

from __future__ import annotations

from collections.abc import Iterator
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from modin.pandas import DataFrame


class PartitionIterator(Iterator):
Expand All @@ -30,7 +36,9 @@ class PartitionIterator(Iterator):
The function to get inner iterables from each partition.
"""

def __init__(self, df, axis, func):
df: DataFrame

def __init__(self, df: DataFrame, axis, func):
self.df = df
self.axis = axis
self.index_iter = (
Expand Down
13 changes: 11 additions & 2 deletions modin/pandas/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

"""Implement Resampler public API."""

from typing import Optional
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Union

import numpy as np
import pandas
Expand All @@ -25,12 +27,19 @@
from modin.pandas.utils import cast_function_modin2pandas
from modin.utils import _inherit_docstrings

if TYPE_CHECKING:
from modin.core.storage_formats import BaseQueryCompiler
from modin.pandas import DataFrame, Series


@_inherit_docstrings(pandas.core.resample.Resampler)
class Resampler(ClassLogger):
_dataframe: Union[DataFrame, Series]
_query_compiler: BaseQueryCompiler

def __init__(
self,
dataframe,
dataframe: Union[DataFrame, Series],
rule,
axis=0,
closed=None,
Expand Down
4 changes: 3 additions & 1 deletion modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
from .utils import _doc_binary_op, cast_function_modin2pandas, is_scalar

if TYPE_CHECKING:
from modin.core.storage_formats import BaseQueryCompiler

from .dataframe import DataFrame

# Dictionary of extensions assigned to this class
Expand Down Expand Up @@ -98,7 +100,7 @@ def __init__(
name=None,
copy=None,
fastpath=lib.no_default,
query_compiler=None,
query_compiler: BaseQueryCompiler = None,
):
from modin.numpy import array

Expand Down
Loading

0 comments on commit c1865b8

Please sign in to comment.