Skip to content

Commit 9b39965

Browse files
authored
[Data] Simplify ArrowBlock and PandasBlock (#58883)
### [Data] Simplify ArrowBlock and PandasBlock

Simplify the inheritance hierarchy for `ArrowBlock` and `PandasBlock` by removing `TableRow` to improve code maintainability.

Signed-off-by: Srinath Krishnamachari <srinath.krishnamachari@anyscale.com>
1 parent d50cb5b commit 9b39965

File tree

4 files changed

+66
-63
lines changed

4 files changed

+66
-63
lines changed

python/ray/data/_internal/arrow_block.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from ray.data._internal.arrow_ops import transform_polars, transform_pyarrow
2828
from ray.data._internal.arrow_ops.transform_pyarrow import shuffle
29-
from ray.data._internal.row import TableRow
29+
from ray.data._internal.row import row_repr, row_repr_pretty, row_str
3030
from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
3131
from ray.data.block import (
3232
Block,
@@ -85,11 +85,14 @@ def get_concat_and_sort_transform(context: DataContext) -> Callable:
8585
return transform_pyarrow.concat_and_sort
8686

8787

88-
class ArrowRow(TableRow):
88+
class ArrowRow(Mapping):
8989
"""
9090
Row of a tabular Dataset backed by an Arrow Table block.
9191
"""
9292

93+
def __init__(self, row: Any):
94+
self._row = row
95+
9396
def __getitem__(self, key: Union[str, List[str]]) -> Any:
9497
from ray.data.extensions import get_arrow_extension_tensor_types
9598

@@ -101,7 +104,9 @@ def get_item(keys: List[str]) -> Any:
101104
# Build a tensor row.
102105
return tuple(
103106
[
104-
ArrowBlockAccessor._build_tensor_row(self._row, col_name=key)
107+
ArrowBlockAccessor._build_tensor_row(
108+
self._row, col_name=key, row_idx=0
109+
)
105110
for key in keys
106111
]
107112
)
@@ -142,6 +147,15 @@ def __len__(self):
142147
def as_pydict(self) -> Dict[str, Any]:
143148
return dict(self.items())
144149

150+
def __str__(self):
151+
return row_str(self)
152+
153+
def __repr__(self):
154+
return row_repr(self)
155+
156+
def _repr_pretty_(self, p, cycle):
157+
return row_repr_pretty(self, p, cycle)
158+
145159

146160
class ArrowBlockBuilder(TableBlockBuilder):
147161
def __init__(self):
@@ -203,6 +217,11 @@ def __init__(self, table: "pyarrow.Table"):
203217
if pyarrow is None:
204218
raise ImportError("Run `pip install pyarrow` for Arrow support")
205219
super().__init__(table)
220+
self._max_chunk_size: Optional[int] = None
221+
222+
def _get_row(self, index: int) -> ArrowRow:
223+
base_row = self.slice(index, index + 1, copy=False)
224+
return ArrowRow(base_row)
206225

207226
def column_names(self) -> List[str]:
208227
return self._table.column_names
@@ -231,10 +250,10 @@ def from_bytes(cls, data: bytes) -> "ArrowBlockAccessor":
231250

232251
@staticmethod
233252
def _build_tensor_row(
234-
row: ArrowRow, col_name: str = TENSOR_COLUMN_NAME
253+
row: ArrowRow, row_idx: int, col_name: str = TENSOR_COLUMN_NAME
235254
) -> np.ndarray:
236255

237-
element = row[col_name][0]
256+
element = row[col_name][row_idx]
238257
arr = element.as_py()
239258

240259
assert isinstance(arr, np.ndarray), type(arr)
@@ -444,16 +463,17 @@ def iter_rows(
444463
) -> Iterator[Union[Mapping, np.ndarray]]:
445464
table = self._table
446465
if public_row_format:
447-
if not hasattr(self, "_max_chunk_size"):
466+
if self._max_chunk_size is None:
448467
# Calling _get_max_chunk_size in constructor makes it slow, so we
449468
# are calling it here only when needed.
450469
self._max_chunk_size = _get_max_chunk_size(
451-
self._table, ARROW_MAX_CHUNK_SIZE_BYTES
470+
table, ARROW_MAX_CHUNK_SIZE_BYTES
452471
)
453472
for batch in table.to_batches(max_chunksize=self._max_chunk_size):
454473
yield from batch.to_pylist()
455474
else:
456-
for i in range(self.num_rows()):
475+
num_rows = self.num_rows()
476+
for i in range(num_rows):
457477
yield self._get_row(i)
458478

459479
def filter(self, predicate_expr: "Expr") -> "pyarrow.Table":

python/ray/data/_internal/pandas_block.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ray.air.constants import TENSOR_COLUMN_NAME
2222
from ray.air.util.tensor_extensions.utils import _should_convert_to_tensor
2323
from ray.data._internal.numpy_support import convert_to_numpy
24-
from ray.data._internal.row import TableRow
24+
from ray.data._internal.row import row_repr, row_repr_pretty, row_str
2525
from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
2626
from ray.data._internal.util import is_null
2727
from ray.data.block import (
@@ -61,11 +61,14 @@ def lazy_import_pandas():
6161
return _pandas
6262

6363

64-
class PandasRow(TableRow):
64+
class PandasRow(Mapping):
6565
"""
6666
Row of a tabular Dataset backed by a Pandas DataFrame block.
6767
"""
6868

69+
def __init__(self, row: Any):
70+
self._row = row
71+
6972
def __getitem__(self, key: Union[str, List[str]]) -> Any:
7073
from ray.data.extensions import TensorArrayElement
7174

@@ -124,6 +127,15 @@ def as_pydict(self) -> Dict[str, Any]:
124127

125128
return pydict
126129

130+
def __str__(self):
131+
return row_str(self)
132+
133+
def __repr__(self):
134+
return row_repr(self)
135+
136+
def _repr_pretty_(self, p, cycle):
137+
return row_repr_pretty(self, p, cycle)
138+
127139

128140
class PandasBlockColumnAccessor(BlockColumnAccessor):
129141
def __init__(self, col: "pandas.Series"):
@@ -330,6 +342,10 @@ class PandasBlockAccessor(TableBlockAccessor):
330342
def __init__(self, table: "pandas.DataFrame"):
331343
super().__init__(table)
332344

345+
def _get_row(self, index: int) -> PandasRow:
346+
base_row = self.slice(index, index + 1, copy=False)
347+
return PandasRow(base_row)
348+
333349
def column_names(self) -> List[str]:
334350
return self._table.columns.tolist()
335351

@@ -341,10 +357,10 @@ def fill_column(self, name: str, value: Any) -> Block:
341357
return self._table.assign(**{name: value})
342358

343359
@staticmethod
344-
def _build_tensor_row(row: PandasRow) -> np.ndarray:
360+
def _build_tensor_row(row: PandasRow, row_idx: int) -> np.ndarray:
345361
from ray.data.extensions import TensorArrayElement
346362

347-
tensor = row[TENSOR_COLUMN_NAME].iloc[0]
363+
tensor = row[TENSOR_COLUMN_NAME].iloc[row_idx]
348364
if isinstance(tensor, TensorArrayElement):
349365
# Getting an item in a Pandas tensor column may return a TensorArrayElement,
350366
# which we have to convert to an ndarray.
@@ -664,9 +680,10 @@ def block_type(self) -> BlockType:
664680
def iter_rows(
665681
self, public_row_format: bool
666682
) -> Iterator[Union[Mapping, np.ndarray]]:
667-
for i in range(self.num_rows()):
683+
num_rows = self.num_rows()
684+
for i in range(num_rows):
668685
row = self._get_row(i)
669-
if public_row_format and isinstance(row, TableRow):
686+
if public_row_format:
670687
yield row.as_pydict()
671688
else:
672689
yield row

python/ray/data/_internal/row.py

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,19 @@
1-
import abc
21
from collections.abc import Mapping
3-
from typing import Any, Dict
42

53

6-
class TableRow(Mapping):
7-
"""
8-
A dict-like row of a tabular ``Dataset``.
4+
def row_str(row: Mapping) -> str:
5+
"""Convert a row to string representation."""
6+
return str(row.as_pydict())
97

10-
This implements the dictionary mapping interface, but provides more
11-
efficient access with less data copying than converting Arrow Tables
12-
or Pandas DataFrames into per-row dicts. This class must be subclassed,
13-
with subclasses implementing ``__getitem__``, ``__iter__``, and ``__len__``.
148

15-
Concrete subclasses include ``ray.data._internal.arrow_block.ArrowRow`` and
16-
``ray.data._internal.pandas_block.PandasRow``.
17-
"""
9+
def row_repr(row: Mapping) -> str:
10+
"""Convert a row to repr representation."""
11+
return str(row)
1812

19-
def __init__(self, row: Any):
20-
"""
21-
Construct a ``TableRow`` (internal API).
2213

23-
Args:
24-
row: The tabular row that backs this row mapping.
25-
"""
26-
self._row = row
14+
def row_repr_pretty(row: Mapping, p, cycle):
15+
"""Pretty print a row."""
16+
from IPython.lib.pretty import _dict_pprinter_factory
2717

28-
@abc.abstractmethod
29-
def as_pydict(self) -> Dict[str, Any]:
30-
"""Convert to a normal Python dict.
31-
32-
This can create a new copy of the row.
33-
"""
34-
...
35-
36-
def __str__(self):
37-
return str(self.as_pydict())
38-
39-
def __repr__(self):
40-
return str(self)
41-
42-
def _repr_pretty_(self, p, cycle):
43-
from IPython.lib.pretty import _dict_pprinter_factory
44-
45-
pprinter = _dict_pprinter_factory("{", "}")
46-
return pprinter(self, p, cycle)
18+
pprinter = _dict_pprinter_factory("{", "}")
19+
return pprinter(row, p, cycle)

python/ray/data/_internal/table_block.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Dict,
77
Iterator,
88
List,
9+
Mapping,
910
Optional,
1011
Sequence,
1112
Tuple,
@@ -18,7 +19,6 @@
1819
from ray._private.ray_constants import env_integer
1920
from ray.air.constants import TENSOR_COLUMN_NAME
2021
from ray.data._internal.block_builder import BlockBuilder
21-
from ray.data._internal.row import TableRow
2222
from ray.data._internal.size_estimator import SizeEstimator
2323
from ray.data._internal.util import (
2424
NULL_SENTINEL,
@@ -73,8 +73,8 @@ def __init__(self, block_type):
7373
self._num_compactions = 0
7474
self._block_type = block_type
7575

76-
def add(self, item: Union[dict, TableRow, np.ndarray]) -> None:
77-
if isinstance(item, TableRow):
76+
def add(self, item: Union[dict, Mapping, np.ndarray]) -> None:
77+
if hasattr(item, "as_pydict"):
7878
item = item.as_pydict()
7979
elif isinstance(item, np.ndarray):
8080
item = {TENSOR_COLUMN_NAME: item}
@@ -169,22 +169,15 @@ def _compact_if_needed(self) -> None:
169169

170170

171171
class TableBlockAccessor(BlockAccessor):
172-
ROW_TYPE: TableRow = TableRow
173-
174172
def __init__(self, table: Any):
175173
self._table = table
176174

177-
def _get_row(self, index: int, copy: bool = False) -> Union[TableRow, np.ndarray]:
178-
base_row = self.slice(index, index + 1, copy=copy)
179-
row = self.ROW_TYPE(base_row)
180-
return row
181-
182175
@staticmethod
183176
def _munge_conflict(name, count):
184177
return f"{name}_{count + 1}"
185178

186179
@staticmethod
187-
def _build_tensor_row(row: TableRow) -> np.ndarray:
180+
def _build_tensor_row(row: Mapping, row_idx: int) -> np.ndarray:
188181
raise NotImplementedError
189182

190183
def to_default(self) -> Block:

0 commit comments

Comments
 (0)