Skip to content

Commit

Permalink
[data] Pre-reqs for implementing stage fusion (ray-project#22374)
Browse files Browse the repository at this point in the history
  • Loading branch information
ericl authored Feb 15, 2022
1 parent 32035eb commit 2158df3
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 83 deletions.
125 changes: 49 additions & 76 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@
from ray.data.impl.compute import cache_wrapper, CallableClass
from ray.data.impl.output_buffer import BlockOutputBuffer
from ray.data.impl.progress_bar import ProgressBar
from ray.data.impl.shuffle import simple_shuffle, _shuffle_reduce
from ray.data.impl.shuffle import simple_shuffle
from ray.data.impl.fast_repartition import fast_repartition
from ray.data.impl.sort import sort_impl
from ray.data.impl.block_list import BlockList
from ray.data.impl.lazy_block_list import LazyBlockList
Expand Down Expand Up @@ -462,74 +463,34 @@ def repartition(self, num_blocks: int, *, shuffle: bool = False) -> "Dataset[T]"

if shuffle:

def do_shuffle(blocks, clear_input_blocks: bool):
# TODO: implement clear_input_blocks
return simple_shuffle(blocks, num_blocks)
def do_shuffle(block_list, clear_input_blocks: bool, block_udf):
if clear_input_blocks:
blocks = block_list.copy()
block_list.clear()
else:
blocks = block_list
return simple_shuffle(blocks, block_udf, num_blocks)

plan = self._plan.with_stage(
AllToAllStage("repartition", num_blocks, do_shuffle)
AllToAllStage(
"repartition", num_blocks, do_shuffle, supports_block_udf=True
)
)
return Dataset(plan, self._epoch, self._lazy)

def do_fast(blocks, clear_input_blocks: bool):
# TODO: this won't work in lazy mode since it references `self`.
# TODO: implement clear_input_blocks.
# Compute the (n-1) indices needed for an equal split of the data.
count = self.count()
indices = []
cur_idx = 0
for _ in range(num_blocks - 1):
cur_idx += count / num_blocks
indices.append(int(cur_idx))
assert len(indices) < num_blocks, (indices, num_blocks)
if indices:
splits = self.split_at_indices(indices)
else:
splits = [self]
# TODO(ekl) include stats for the split tasks. We may also want to
# consider combining the split and coalesce tasks as an optimization.

# Coalesce each split into a single block.
reduce_task = cached_remote_fn(_shuffle_reduce).options(num_returns=2)
reduce_bar = ProgressBar("Repartition", position=0, total=len(splits))
reduce_out = [
reduce_task.remote(*s.get_internal_block_refs())
for s in splits
if s.num_blocks() > 0
]
del splits # Early-release memory.
new_blocks, new_metadata = zip(*reduce_out)
new_blocks, new_metadata = list(new_blocks), list(new_metadata)
new_metadata = reduce_bar.fetch_until_complete(new_metadata)
reduce_bar.close()

# Handle empty blocks.
if len(new_blocks) < num_blocks:
from ray.data.impl.arrow_block import ArrowBlockBuilder
from ray.data.impl.pandas_block import PandasBlockBuilder
from ray.data.impl.simple_block import SimpleBlockBuilder

num_empties = num_blocks - len(new_blocks)
dataset_format = self._dataset_format()
if dataset_format == "arrow":
builder = ArrowBlockBuilder()
elif dataset_format == "pandas":
builder = PandasBlockBuilder()

else:

def do_fast_repartition(block_list, clear_input_blocks: bool, _):
if clear_input_blocks:
blocks = block_list.copy()
block_list.clear()
else:
builder = SimpleBlockBuilder()
empty_block = builder.build()
empty_meta = BlockAccessor.for_block(empty_block).get_metadata(
input_files=None, exec_stats=None
) # No stats for empty block.
empty_blocks, empty_metadata = zip(
*[(ray.put(empty_block), empty_meta) for _ in range(num_empties)]
)
new_blocks += empty_blocks
new_metadata += empty_metadata
blocks = block_list
return fast_repartition(blocks, num_blocks)

return BlockList(new_blocks, new_metadata), {}
plan = self._plan.with_stage(
AllToAllStage("repartition", num_blocks, do_fast_repartition)
)

plan = self._plan.with_stage(AllToAllStage("repartition", num_blocks, do_fast))
return Dataset(plan, self._epoch, self._lazy)

def random_shuffle(
Expand All @@ -538,7 +499,7 @@ def random_shuffle(
seed: Optional[int] = None,
num_blocks: Optional[int] = None,
_spread_resource_prefix: Optional[str] = None,
_move: bool = False,
_move: bool = False, # TODO: deprecate.
) -> "Dataset[T]":
"""Randomly shuffle the elements of this dataset.
Expand All @@ -563,18 +524,18 @@ def random_shuffle(
The shuffled dataset.
"""

def do_shuffle(block_list, clear_input_blocks: bool):
def do_shuffle(block_list, clear_input_blocks: bool, block_udf):
num_blocks = block_list.executed_num_blocks() # Blocking.
if num_blocks == 0:
return block_list, {}
# TODO: implement clear_input_blocks instead.
if _move:
if _move or clear_input_blocks:
blocks = block_list.copy()
block_list.clear()
else:
blocks = block_list
new_blocks, stage_info = simple_shuffle(
blocks,
block_udf,
num_blocks,
random_shuffle=True,
random_seed=seed,
Expand All @@ -583,7 +544,9 @@ def do_shuffle(block_list, clear_input_blocks: bool):
return new_blocks, stage_info

plan = self._plan.with_stage(
AllToAllStage("random_shuffle", num_blocks, do_shuffle)
AllToAllStage(
"random_shuffle", num_blocks, do_shuffle, supports_block_udf=True
)
)
return Dataset(plan, self._epoch, self._lazy)

Expand Down Expand Up @@ -1005,7 +968,8 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
)
_epoch_warned = True
dataset_stats = DatasetStats(
stages={"union": []}, parent=[d._plan.stats() for d in datasets]
stages={"union": []},
parent=[d._plan.stats() for d in datasets],
)
dataset_stats.time_total_s = time.perf_counter() - start_time
return Dataset(
Expand Down Expand Up @@ -1409,11 +1373,15 @@ def sort(
A new, sorted dataset.
"""

def do_sort(blocks, keep_input_blocks: bool):
# TODO: implement clear_input_blocks
def do_sort(block_list, clear_input_blocks: bool, block_udf):
# Handle empty dataset.
if blocks.initial_num_blocks() == 0:
return blocks, {}
if block_list.initial_num_blocks() == 0:
return block_list, {}
if clear_input_blocks:
blocks = block_list.copy()
block_list.clear()
else:
blocks = block_list
if isinstance(key, list):
if not key:
raise ValueError("`key` must be a list of non-zero length")
Expand Down Expand Up @@ -1449,11 +1417,13 @@ def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]":
comes from the first dataset and v comes from the second.
"""

def do_zip_all(blocks, keep_input_blocks: bool):
# TODO: implement clear_input_blocks
blocks1 = blocks.get_blocks()
def do_zip_all(block_list, clear_input_blocks: bool, block_udf):
blocks1 = block_list.get_blocks()
blocks2 = other.get_internal_block_refs()

if clear_input_blocks:
block_list.clear()

if len(blocks1) != len(blocks2):
# TODO(ekl) consider supporting if num_rows are equal.
raise ValueError(
Expand All @@ -1478,6 +1448,9 @@ def do_zip(block1: Block, block2: Block) -> (Block, BlockMetadata):
blocks.append(res)
metadata.append(meta)

# Early release memory.
del blocks1, blocks2

# TODO(ekl) it might be nice to have a progress bar here.
metadata = ray.get(metadata)
blocks = BlockList(blocks, metadata)
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/grouped_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def aggregate(self, *aggs: AggregateFn) -> Dataset[U]:
If groupby key is ``None`` then the key part of return is omitted.
"""

def do_agg(blocks, clear_input_blocks: bool):
def do_agg(blocks, clear_input_blocks: bool, block_udf):
# TODO: implement clear_input_blocks
stage_info = {}
if len(aggs) == 0:
Expand Down
74 changes: 74 additions & 0 deletions python/ray/data/impl/fast_repartition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import ray

from ray.data.block import BlockAccessor
from ray.data.impl.block_list import BlockList
from ray.data.impl.plan import ExecutionPlan
from ray.data.impl.progress_bar import ProgressBar
from ray.data.impl.remote_fn import cached_remote_fn
from ray.data.impl.shuffle import _shuffle_reduce
from ray.data.impl.stats import DatasetStats


def fast_repartition(blocks, num_blocks):
    """Repartition ``blocks`` into ``num_blocks`` blocks without shuffling.

    The input is wrapped in a temporary eager ``Dataset``, split at
    ``num_blocks - 1`` evenly spaced row indices, and every non-empty split
    is coalesced into a single block by a remote ``_shuffle_reduce`` task.
    If fewer than ``num_blocks`` blocks come out of that step, the remainder
    is filled with empty blocks built with the builder matching the
    dataset's format.

    Args:
        blocks: The input block list to repartition.
        num_blocks: The number of output blocks to produce.

    Returns:
        A ``(BlockList, dict)`` tuple: the repartitioned blocks and an empty
        stage-info dict (no per-stage stats are reported yet).
    """
    from ray.data.dataset import Dataset

    wrapped_ds = Dataset(
        ExecutionPlan(blocks, DatasetStats(stages={}, parent=None)), 0, lazy=False
    )
    # Compute the (n-1) indices needed for an equal split of the data.
    count = wrapped_ds.count()
    dataset_format = wrapped_ds._dataset_format()
    # Exact integer arithmetic: accumulating the float `count / num_blocks`
    # can drift just below an integer boundary and truncate one row short.
    indices = [count * i // num_blocks for i in range(1, num_blocks)]
    assert len(indices) < num_blocks, (indices, num_blocks)
    if indices:
        splits = wrapped_ds.split_at_indices(indices)
    else:
        splits = [wrapped_ds]
    # TODO(ekl) include stats for the split tasks. We may also want to
    # consider combining the split and coalesce tasks as an optimization.

    # Coalesce each split into a single block.
    reduce_task = cached_remote_fn(_shuffle_reduce).options(num_returns=2)
    reduce_bar = ProgressBar("Repartition", position=0, total=len(splits))
    try:
        reduce_out = [
            reduce_task.remote(*s.get_internal_block_refs())
            for s in splits
            if s.num_blocks() > 0
        ]

        # Early-release memory.
        del splits, blocks, wrapped_ds

        if reduce_out:
            new_blocks, new_metadata = (list(refs) for refs in zip(*reduce_out))
            new_metadata = reduce_bar.fetch_until_complete(new_metadata)
        else:
            # Every split was empty, so no reduce task was launched.
            # `zip(*[])` yields nothing, so unpacking it into two names
            # would raise ValueError; start from empty lists instead.
            new_blocks, new_metadata = [], []
    finally:
        # Close the bar even when fetching the metadata fails.
        reduce_bar.close()

    # Handle empty blocks.
    if len(new_blocks) < num_blocks:
        from ray.data.impl.arrow_block import ArrowBlockBuilder
        from ray.data.impl.pandas_block import PandasBlockBuilder
        from ray.data.impl.simple_block import SimpleBlockBuilder

        num_empties = num_blocks - len(new_blocks)
        if dataset_format == "arrow":
            builder = ArrowBlockBuilder()
        elif dataset_format == "pandas":
            builder = PandasBlockBuilder()
        else:
            builder = SimpleBlockBuilder()
        empty_block = builder.build()
        empty_meta = BlockAccessor.for_block(empty_block).get_metadata(
            input_files=None, exec_stats=None
        )  # No stats for empty block.
        # One object-store entry per padding block; the metadata is shared.
        new_blocks.extend(ray.put(empty_block) for _ in range(num_empties))
        new_metadata.extend([empty_meta] * num_empties)

    return BlockList(new_blocks, new_metadata), {}
8 changes: 6 additions & 2 deletions python/ray/data/impl/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,18 @@ def __init__(
self,
name: str,
num_blocks: Optional[int],
fn: Callable[[BlockList, bool], Tuple[BlockList, dict]],
fn: Callable[[BlockList, bool, Callable], Tuple[BlockList, dict]],
supports_block_udf: bool = False,
block_udf=None,
):
super().__init__(name, num_blocks)
self.fn = fn
self.supports_block_udf = supports_block_udf
self.block_udf = block_udf

def __call__(
self, blocks: BlockList, clear_input_blocks: bool
) -> Tuple[BlockList, dict]:
blocks, stage_info = self.fn(blocks, clear_input_blocks)
blocks, stage_info = self.fn(blocks, clear_input_blocks, self.block_udf)
assert isinstance(blocks, BlockList), blocks
return blocks, stage_info
16 changes: 14 additions & 2 deletions python/ray/data/impl/shuffle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
import math
from typing import TypeVar, List, Optional, Dict, Any, Tuple, Union
from typing import TypeVar, List, Optional, Dict, Any, Tuple, Union, Callable, Iterable

import numpy as np

Expand All @@ -17,6 +17,7 @@

def simple_shuffle(
input_blocks: BlockList,
block_udf: Optional[Callable[[Block], Iterable[Block]]],
output_num_blocks: int,
*,
random_shuffle: bool = False,
Expand Down Expand Up @@ -55,7 +56,7 @@ def simple_shuffle(
**map_ray_remote_args,
num_returns=1 + output_num_blocks,
resources=next(map_resource_iter)
).remote(block, i, output_num_blocks, random_shuffle, random_seed)
).remote(block, block_udf, i, output_num_blocks, random_shuffle, random_seed)
for i, block in enumerate(input_blocks)
]

Expand Down Expand Up @@ -103,13 +104,24 @@ def simple_shuffle(

def _shuffle_map(
block: Block,
block_udf: Optional[Callable[[Block], Iterable[Block]]],
idx: int,
output_num_blocks: int,
random_shuffle: bool,
random_seed: Optional[int],
) -> List[Union[BlockMetadata, Block]]:
"""Returns list of [BlockMetadata, O1, O2, O3, ...output_num_blocks]."""
stats = BlockExecStats.builder()
if block_udf:
# TODO(ekl) note that this effectively disables block splitting.
pieces = list(block_udf(block))
if len(pieces) > 1:
builder = BlockAccessor.for_block(pieces[0]).builder()
for p in pieces:
builder.add_block(p)
block = builder.build()
else:
block = pieces[0]
block = BlockAccessor.for_block(block)

# Randomize the distribution of records to blocks.
Expand Down
4 changes: 4 additions & 0 deletions python/ray/data/impl/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def sort_impl(
result = sort_block.remote(block, boundaries, key, descending)
map_results[i, :] = result[:-1]
map_meta.append(result[-1])

# Early release memory.
del blocks

map_bar = ProgressBar("Sort Map", len(map_results))
map_bar.block_until_complete(map_meta)
map_bar.close()
Expand Down
5 changes: 3 additions & 2 deletions python/ray/data/impl/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def build_multistage(

def build(self, final_blocks: BlockList) -> "DatasetStats":
stats = DatasetStats(
stages={self.stage_name: final_blocks.get_metadata()}, parent=self.parent
stages={self.stage_name: final_blocks.get_metadata()},
parent=self.parent,
)
stats.time_total_s = time.perf_counter() - self.start_time
return stats
Expand Down Expand Up @@ -205,7 +206,7 @@ def summary_string(self, already_printed: Set[str] = None) -> str:
out += p.summary_string(already_printed)
out += "\n"
first = True
for stage_name, metadata in sorted(self.stages.items()):
for stage_name, metadata in self.stages.items():
stage_uuid = self.dataset_uuid + stage_name
if first:
first = False
Expand Down

0 comments on commit 2158df3

Please sign in to comment.