diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 561d320af0527..d13e7fa683513 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -72,7 +72,8 @@
 from ray.data.impl.compute import cache_wrapper, CallableClass
 from ray.data.impl.output_buffer import BlockOutputBuffer
 from ray.data.impl.progress_bar import ProgressBar
-from ray.data.impl.shuffle import simple_shuffle, _shuffle_reduce
+from ray.data.impl.shuffle import simple_shuffle
+from ray.data.impl.fast_repartition import fast_repartition
 from ray.data.impl.sort import sort_impl
 from ray.data.impl.block_list import BlockList
 from ray.data.impl.lazy_block_list import LazyBlockList
@@ -462,74 +463,34 @@ def repartition(self, num_blocks: int, *, shuffle: bool = False) -> "Dataset[T]"
 
         if shuffle:
 
-            def do_shuffle(blocks, clear_input_blocks: bool):
-                # TODO: implement clear_input_blocks
-                return simple_shuffle(blocks, num_blocks)
+            def do_shuffle(block_list, clear_input_blocks: bool, block_udf):
+                if clear_input_blocks:
+                    blocks = block_list.copy()
+                    block_list.clear()
+                else:
+                    blocks = block_list
+                return simple_shuffle(blocks, block_udf, num_blocks)
 
             plan = self._plan.with_stage(
-                AllToAllStage("repartition", num_blocks, do_shuffle)
+                AllToAllStage(
+                    "repartition", num_blocks, do_shuffle, supports_block_udf=True
+                )
             )
-            return Dataset(plan, self._epoch, self._lazy)
-
-        def do_fast(blocks, clear_input_blocks: bool):
-            # TODO: this won't work in lazy mode since it references `self`.
-            # TODO: implement clear_input_blocks.
-            # Compute the (n-1) indices needed for an equal split of the data.
-            count = self.count()
-            indices = []
-            cur_idx = 0
-            for _ in range(num_blocks - 1):
-                cur_idx += count / num_blocks
-                indices.append(int(cur_idx))
-            assert len(indices) < num_blocks, (indices, num_blocks)
-            if indices:
-                splits = self.split_at_indices(indices)
-            else:
-                splits = [self]
-            # TODO(ekl) include stats for the split tasks. We may also want to
-            # consider combining the split and coalesce tasks as an optimization.
-
-            # Coalesce each split into a single block.
-            reduce_task = cached_remote_fn(_shuffle_reduce).options(num_returns=2)
-            reduce_bar = ProgressBar("Repartition", position=0, total=len(splits))
-            reduce_out = [
-                reduce_task.remote(*s.get_internal_block_refs())
-                for s in splits
-                if s.num_blocks() > 0
-            ]
-            del splits  # Early-release memory.
-            new_blocks, new_metadata = zip(*reduce_out)
-            new_blocks, new_metadata = list(new_blocks), list(new_metadata)
-            new_metadata = reduce_bar.fetch_until_complete(new_metadata)
-            reduce_bar.close()
-
-            # Handle empty blocks.
-            if len(new_blocks) < num_blocks:
-                from ray.data.impl.arrow_block import ArrowBlockBuilder
-                from ray.data.impl.pandas_block import PandasBlockBuilder
-                from ray.data.impl.simple_block import SimpleBlockBuilder
-
-                num_empties = num_blocks - len(new_blocks)
-                dataset_format = self._dataset_format()
-                if dataset_format == "arrow":
-                    builder = ArrowBlockBuilder()
-                elif dataset_format == "pandas":
-                    builder = PandasBlockBuilder()
+
+        else:
+
+            def do_fast_repartition(block_list, clear_input_blocks: bool, _):
+                if clear_input_blocks:
+                    blocks = block_list.copy()
+                    block_list.clear()
                 else:
-                    builder = SimpleBlockBuilder()
-                empty_block = builder.build()
-                empty_meta = BlockAccessor.for_block(empty_block).get_metadata(
-                    input_files=None, exec_stats=None
-                )  # No stats for empty block.
-                empty_blocks, empty_metadata = zip(
-                    *[(ray.put(empty_block), empty_meta) for _ in range(num_empties)]
-                )
-                new_blocks += empty_blocks
-                new_metadata += empty_metadata
+                    blocks = block_list
+                return fast_repartition(blocks, num_blocks)
 
-            return BlockList(new_blocks, new_metadata), {}
+            plan = self._plan.with_stage(
+                AllToAllStage("repartition", num_blocks, do_fast_repartition)
+            )
 
-        plan = self._plan.with_stage(AllToAllStage("repartition", num_blocks, do_fast))
         return Dataset(plan, self._epoch, self._lazy)
 
     def random_shuffle(
@@ -538,7 +499,7 @@ def random_shuffle(
         seed: Optional[int] = None,
         num_blocks: Optional[int] = None,
        _spread_resource_prefix: Optional[str] = None,
-        _move: bool = False,
+        _move: bool = False,  # TODO: deprecate.
     ) -> "Dataset[T]":
         """Randomly shuffle the elements of this dataset.
 
@@ -563,18 +524,18 @@
             The shuffled dataset.
         """
 
-        def do_shuffle(block_list, clear_input_blocks: bool):
+        def do_shuffle(block_list, clear_input_blocks: bool, block_udf):
             num_blocks = block_list.executed_num_blocks()  # Blocking.
             if num_blocks == 0:
                 return block_list, {}
-            # TODO: implement clear_input_blocks instead.
-            if _move:
+            if _move or clear_input_blocks:
                 blocks = block_list.copy()
                 block_list.clear()
             else:
                 blocks = block_list
             new_blocks, stage_info = simple_shuffle(
                 blocks,
+                block_udf,
                 num_blocks,
                 random_shuffle=True,
                 random_seed=seed,
@@ -583,7 +544,9 @@ def do_shuffle(block_list, clear_input_blocks: bool):
             return new_blocks, stage_info
 
         plan = self._plan.with_stage(
-            AllToAllStage("random_shuffle", num_blocks, do_shuffle)
+            AllToAllStage(
+                "random_shuffle", num_blocks, do_shuffle, supports_block_udf=True
+            )
         )
         return Dataset(plan, self._epoch, self._lazy)
 
@@ -1005,7 +968,8 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
                 )
                 _epoch_warned = True
         dataset_stats = DatasetStats(
-            stages={"union": []}, parent=[d._plan.stats() for d in datasets]
+            stages={"union": []},
+            parent=[d._plan.stats() for d in datasets],
         )
         dataset_stats.time_total_s = time.perf_counter() - start_time
         return Dataset(
@@ -1409,11 +1373,15 @@ def sort(
             A new, sorted dataset.
         """
 
-        def do_sort(blocks, keep_input_blocks: bool):
-            # TODO: implement clear_input_blocks
+        def do_sort(block_list, clear_input_blocks: bool, block_udf):
             # Handle empty dataset.
-            if blocks.initial_num_blocks() == 0:
-                return blocks, {}
+            if block_list.initial_num_blocks() == 0:
+                return block_list, {}
+            if clear_input_blocks:
+                blocks = block_list.copy()
+                block_list.clear()
+            else:
+                blocks = block_list
             if isinstance(key, list):
                 if not key:
                     raise ValueError("`key` must be a list of non-zero length")
@@ -1449,11 +1417,13 @@ def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]":
             comes from the first dataset and v comes from the second.
         """
 
-        def do_zip_all(blocks, keep_input_blocks: bool):
-            # TODO: implement clear_input_blocks
-            blocks1 = blocks.get_blocks()
+        def do_zip_all(block_list, clear_input_blocks: bool, block_udf):
+            blocks1 = block_list.get_blocks()
             blocks2 = other.get_internal_block_refs()
 
+            if clear_input_blocks:
+                block_list.clear()
+
             if len(blocks1) != len(blocks2):
                 # TODO(ekl) consider supporting if num_rows are equal.
                 raise ValueError(
@@ -1478,6 +1448,9 @@ def do_zip(block1: Block, block2: Block) -> (Block, BlockMetadata):
                 blocks.append(res)
                 metadata.append(meta)
 
+            # Early release memory.
+            del blocks1, blocks2
+
             # TODO(ekl) it might be nice to have a progress bar here.
             metadata = ray.get(metadata)
             blocks = BlockList(blocks, metadata)
diff --git a/python/ray/data/grouped_dataset.py b/python/ray/data/grouped_dataset.py
index d3dd70cdbf7dc..be69539a17886 100644
--- a/python/ray/data/grouped_dataset.py
+++ b/python/ray/data/grouped_dataset.py
@@ -56,7 +56,7 @@ def aggregate(self, *aggs: AggregateFn) -> Dataset[U]:
             If groupby key is ``None`` then the key part of return is omitted.
         """
 
-        def do_agg(blocks, clear_input_blocks: bool):
+        def do_agg(blocks, clear_input_blocks: bool, block_udf):
             # TODO: implement clear_input_blocks
             stage_info = {}
             if len(aggs) == 0:
diff --git a/python/ray/data/impl/fast_repartition.py b/python/ray/data/impl/fast_repartition.py
new file mode 100644
index 0000000000000..c7207ef9310c2
--- /dev/null
+++ b/python/ray/data/impl/fast_repartition.py
@@ -0,0 +1,74 @@
+import ray
+
+from ray.data.block import BlockAccessor
+from ray.data.impl.block_list import BlockList
+from ray.data.impl.plan import ExecutionPlan
+from ray.data.impl.progress_bar import ProgressBar
+from ray.data.impl.remote_fn import cached_remote_fn
+from ray.data.impl.shuffle import _shuffle_reduce
+from ray.data.impl.stats import DatasetStats
+
+
+def fast_repartition(blocks, num_blocks):
+    from ray.data.dataset import Dataset
+
+    wrapped_ds = Dataset(
+        ExecutionPlan(blocks, DatasetStats(stages={}, parent=None)), 0, lazy=False
+    )
+    # Compute the (n-1) indices needed for an equal split of the data.
+    count = wrapped_ds.count()
+    dataset_format = wrapped_ds._dataset_format()
+    indices = []
+    cur_idx = 0
+    for _ in range(num_blocks - 1):
+        cur_idx += count / num_blocks
+        indices.append(int(cur_idx))
+    assert len(indices) < num_blocks, (indices, num_blocks)
+    if indices:
+        splits = wrapped_ds.split_at_indices(indices)
+    else:
+        splits = [wrapped_ds]
+    # TODO(ekl) include stats for the split tasks. We may also want to
+    # consider combining the split and coalesce tasks as an optimization.
+
+    # Coalesce each split into a single block.
+    reduce_task = cached_remote_fn(_shuffle_reduce).options(num_returns=2)
+    reduce_bar = ProgressBar("Repartition", position=0, total=len(splits))
+    reduce_out = [
+        reduce_task.remote(*s.get_internal_block_refs())
+        for s in splits
+        if s.num_blocks() > 0
+    ]
+
+    # Early-release memory.
+    del splits, blocks, wrapped_ds
+
+    new_blocks, new_metadata = zip(*reduce_out)
+    new_blocks, new_metadata = list(new_blocks), list(new_metadata)
+    new_metadata = reduce_bar.fetch_until_complete(new_metadata)
+    reduce_bar.close()
+
+    # Handle empty blocks.
+    if len(new_blocks) < num_blocks:
+        from ray.data.impl.arrow_block import ArrowBlockBuilder
+        from ray.data.impl.pandas_block import PandasBlockBuilder
+        from ray.data.impl.simple_block import SimpleBlockBuilder
+
+        num_empties = num_blocks - len(new_blocks)
+        if dataset_format == "arrow":
+            builder = ArrowBlockBuilder()
+        elif dataset_format == "pandas":
+            builder = PandasBlockBuilder()
+        else:
+            builder = SimpleBlockBuilder()
+        empty_block = builder.build()
+        empty_meta = BlockAccessor.for_block(empty_block).get_metadata(
+            input_files=None, exec_stats=None
+        )  # No stats for empty block.
+        empty_blocks, empty_metadata = zip(
+            *[(ray.put(empty_block), empty_meta) for _ in range(num_empties)]
+        )
+        new_blocks += empty_blocks
+        new_metadata += empty_metadata
+
+    return BlockList(new_blocks, new_metadata), {}
diff --git a/python/ray/data/impl/plan.py b/python/ray/data/impl/plan.py
index a59cf52d8dc0b..a05b2a6372f6e 100644
--- a/python/ray/data/impl/plan.py
+++ b/python/ray/data/impl/plan.py
@@ -140,14 +140,18 @@ def __init__(
         self,
         name: str,
         num_blocks: Optional[int],
-        fn: Callable[[BlockList, bool], Tuple[BlockList, dict]],
+        fn: Callable[[BlockList, bool, Callable], Tuple[BlockList, dict]],
+        supports_block_udf: bool = False,
+        block_udf=None,
     ):
         super().__init__(name, num_blocks)
         self.fn = fn
+        self.supports_block_udf = supports_block_udf
+        self.block_udf = block_udf
 
     def __call__(
         self, blocks: BlockList, clear_input_blocks: bool
     ) -> Tuple[BlockList, dict]:
-        blocks, stage_info = self.fn(blocks, clear_input_blocks)
+        blocks, stage_info = self.fn(blocks, clear_input_blocks, self.block_udf)
         assert isinstance(blocks, BlockList), blocks
         return blocks, stage_info
diff --git a/python/ray/data/impl/shuffle.py b/python/ray/data/impl/shuffle.py
index ae8604791e030..9e903bfe97eb6 100644
--- a/python/ray/data/impl/shuffle.py
+++ b/python/ray/data/impl/shuffle.py
@@ -1,6 +1,6 @@
 import itertools
 import math
-from typing import TypeVar, List, Optional, Dict, Any, Tuple, Union
+from typing import TypeVar, List, Optional, Dict, Any, Tuple, Union, Callable, Iterable
 
 import numpy as np
 
@@ -17,6 +17,7 @@
 
 def simple_shuffle(
     input_blocks: BlockList,
+    block_udf: Optional[Callable[[Block], Iterable[Block]]],
     output_num_blocks: int,
     *,
     random_shuffle: bool = False,
@@ -55,7 +56,7 @@ def simple_shuffle(
             **map_ray_remote_args,
             num_returns=1 + output_num_blocks,
             resources=next(map_resource_iter)
-        ).remote(block, i, output_num_blocks, random_shuffle, random_seed)
+        ).remote(block, block_udf, i, output_num_blocks, random_shuffle, random_seed)
         for i, block in enumerate(input_blocks)
     ]
 
@@ -103,6 +104,7 @@ def simple_shuffle(
 
 def _shuffle_map(
     block: Block,
+    block_udf: Optional[Callable[[Block], Iterable[Block]]],
     idx: int,
     output_num_blocks: int,
     random_shuffle: bool,
@@ -110,6 +112,16 @@ def _shuffle_map(
 ) -> List[Union[BlockMetadata, Block]]:
     """Returns list of [BlockMetadata, O1, O2, O3, ...output_num_blocks]."""
     stats = BlockExecStats.builder()
+    if block_udf:
+        # TODO(ekl) note that this effectively disables block splitting.
+        pieces = list(block_udf(block))
+        if len(pieces) > 1:
+            builder = BlockAccessor.for_block(pieces[0]).builder()
+            for p in pieces:
+                builder.add_block(p)
+            block = builder.build()
+        else:
+            block = pieces[0]
     block = BlockAccessor.for_block(block)
 
     # Randomize the distribution of records to blocks.
diff --git a/python/ray/data/impl/sort.py b/python/ray/data/impl/sort.py
index 358ba76d84e4a..c0f973257c994 100644
--- a/python/ray/data/impl/sort.py
+++ b/python/ray/data/impl/sort.py
@@ -103,6 +103,10 @@ def sort_impl(
         result = sort_block.remote(block, boundaries, key, descending)
         map_results[i, :] = result[:-1]
         map_meta.append(result[-1])
+
+    # Early release memory.
+    del blocks
+
     map_bar = ProgressBar("Sort Map", len(map_results))
     map_bar.block_until_complete(map_meta)
     map_bar.close()
diff --git a/python/ray/data/impl/stats.py b/python/ray/data/impl/stats.py
index 6edc662ef87f9..3572c711a5c50 100644
--- a/python/ray/data/impl/stats.py
+++ b/python/ray/data/impl/stats.py
@@ -63,7 +63,8 @@ def build_multistage(
 
     def build(self, final_blocks: BlockList) -> "DatasetStats":
         stats = DatasetStats(
-            stages={self.stage_name: final_blocks.get_metadata()}, parent=self.parent
+            stages={self.stage_name: final_blocks.get_metadata()},
+            parent=self.parent,
        )
         stats.time_total_s = time.perf_counter() - self.start_time
         return stats
@@ -205,7 +206,7 @@ def summary_string(self, already_printed: Set[str] = None) -> str:
                 out += p.summary_string(already_printed)
                 out += "\n"
         first = True
-        for stage_name, metadata in sorted(self.stages.items()):
+        for stage_name, metadata in self.stages.items():
             stage_uuid = self.dataset_uuid + stage_name
             if first:
                 first = False