-
Notifications
You must be signed in to change notification settings - Fork 6.9k
[Data] Simplify ArrowBlock and PandasBlock #58883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -21,7 +21,7 @@ | |||||
| from ray.air.constants import TENSOR_COLUMN_NAME | ||||||
| from ray.air.util.tensor_extensions.utils import _should_convert_to_tensor | ||||||
| from ray.data._internal.numpy_support import convert_to_numpy | ||||||
| from ray.data._internal.row import TableRow | ||||||
| from ray.data._internal.row import row_repr, row_repr_pretty, row_str | ||||||
| from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder | ||||||
| from ray.data._internal.util import is_null | ||||||
| from ray.data.block import ( | ||||||
|
|
@@ -61,11 +61,14 @@ def lazy_import_pandas(): | |||||
| return _pandas | ||||||
|
|
||||||
|
|
||||||
| class PandasRow(TableRow): | ||||||
| class PandasRow(Mapping): | ||||||
| """ | ||||||
| Row of a tabular Dataset backed by a Pandas DataFrame block. | ||||||
| """ | ||||||
|
|
||||||
| def __init__(self, row: Any): | ||||||
| self._row = row | ||||||
|
|
||||||
| def __getitem__(self, key: Union[str, List[str]]) -> Any: | ||||||
| from ray.data.extensions import TensorArrayElement | ||||||
|
|
||||||
|
|
@@ -124,6 +127,15 @@ def as_pydict(self) -> Dict[str, Any]: | |||||
|
|
||||||
| return pydict | ||||||
|
|
||||||
| def __str__(self): | ||||||
| return row_str(self) | ||||||
|
|
||||||
| def __repr__(self): | ||||||
| return row_repr(self) | ||||||
|
|
||||||
| def _repr_pretty_(self, p, cycle): | ||||||
| return row_repr_pretty(self, p, cycle) | ||||||
|
|
||||||
|
|
||||||
| class PandasBlockColumnAccessor(BlockColumnAccessor): | ||||||
| def __init__(self, col: "pandas.Series"): | ||||||
|
|
@@ -330,6 +342,10 @@ class PandasBlockAccessor(TableBlockAccessor): | |||||
| def __init__(self, table: "pandas.DataFrame"): | ||||||
| super().__init__(table) | ||||||
|
|
||||||
| def _get_row(self, index: int) -> PandasRow: | ||||||
| base_row = self.slice(index, index + 1, copy=False) | ||||||
| return PandasRow(base_row) | ||||||
|
|
||||||
| def column_names(self) -> List[str]: | ||||||
| return self._table.columns.tolist() | ||||||
|
|
||||||
|
|
@@ -341,10 +357,10 @@ def fill_column(self, name: str, value: Any) -> Block: | |||||
| return self._table.assign(**{name: value}) | ||||||
|
|
||||||
| @staticmethod | ||||||
| def _build_tensor_row(row: PandasRow) -> np.ndarray: | ||||||
| def _build_tensor_row(row: PandasRow, row_idx: int) -> np.ndarray: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The type hint for the
Suggested change
|
||||||
| from ray.data.extensions import TensorArrayElement | ||||||
|
|
||||||
| tensor = row[TENSOR_COLUMN_NAME].iloc[0] | ||||||
| tensor = row[TENSOR_COLUMN_NAME].iloc[row_idx] | ||||||
| if isinstance(tensor, TensorArrayElement): | ||||||
| # Getting an item in a Pandas tensor column may return a TensorArrayElement, | ||||||
| # which we have to convert to an ndarray. | ||||||
|
|
@@ -664,9 +680,10 @@ def block_type(self) -> BlockType: | |||||
| def iter_rows( | ||||||
| self, public_row_format: bool | ||||||
| ) -> Iterator[Union[Mapping, np.ndarray]]: | ||||||
| for i in range(self.num_rows()): | ||||||
| num_rows = self.num_rows() | ||||||
| for i in range(num_rows): | ||||||
| row = self._get_row(i) | ||||||
| if public_row_format and isinstance(row, TableRow): | ||||||
| if public_row_format: | ||||||
| yield row.as_pydict() | ||||||
| else: | ||||||
| yield row | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,46 +1,19 @@ | ||
| import abc | ||
| from collections.abc import Mapping | ||
| from typing import Any, Dict | ||
|
|
||
|
|
||
| class TableRow(Mapping): | ||
| """ | ||
| A dict-like row of a tabular ``Dataset``. | ||
| def row_str(row: Mapping) -> str: | ||
| """Convert a row to string representation.""" | ||
| return str(row.as_pydict()) | ||
|
|
||
| This implements the dictionary mapping interface, but provides more | ||
| efficient access with less data copying than converting Arrow Tables | ||
| or Pandas DataFrames into per-row dicts. This class must be subclassed, | ||
| with subclasses implementing ``__getitem__``, ``__iter__``, and ``__len__``. | ||
|
|
||
| Concrete subclasses include ``ray.data._internal.arrow_block.ArrowRow`` and | ||
| ``ray.data._internal.pandas_block.PandasRow``. | ||
| """ | ||
| def row_repr(row: Mapping) -> str: | ||
| """Convert a row to repr representation.""" | ||
| return str(row) | ||
|
|
||
| def __init__(self, row: Any): | ||
| """ | ||
| Construct a ``TableRow`` (internal API). | ||
|
|
||
| Args: | ||
| row: The tabular row that backs this row mapping. | ||
| """ | ||
| self._row = row | ||
| def row_repr_pretty(row: Mapping, p, cycle): | ||
| """Pretty print a row.""" | ||
| from IPython.lib.pretty import _dict_pprinter_factory | ||
|
|
||
| @abc.abstractmethod | ||
| def as_pydict(self) -> Dict[str, Any]: | ||
| """Convert to a normal Python dict. | ||
|
|
||
| This can create a new copy of the row. | ||
| """ | ||
| ... | ||
|
|
||
| def __str__(self): | ||
| return str(self.as_pydict()) | ||
|
|
||
| def __repr__(self): | ||
| return str(self) | ||
|
|
||
| def _repr_pretty_(self, p, cycle): | ||
| from IPython.lib.pretty import _dict_pprinter_factory | ||
|
|
||
| pprinter = _dict_pprinter_factory("{", "}") | ||
| return pprinter(self, p, cycle) | ||
| pprinter = _dict_pprinter_factory("{", "}") | ||
| return pprinter(row, p, cycle) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type hint for the `row` parameter is `ArrowRow`, but it appears to be incorrect. The only call site in this file, `ArrowRow.__getitem__`, passes `self._row`, which is a `pyarrow.Table`. If `row` were an `ArrowRow` instance, the expression `row[col_name]` within this method would trigger a recursive call to `ArrowRow.__getitem__`. To improve type safety and clarity, please update the type hint to `row: "pyarrow.Table"`.