[data] Fix issues with combining use of materialize() and streaming_split() (#36092)

Fixes two bugs:

- materialize() isn't respected by streaming_split().
- Multiple iterations over a materialized dataset don't work correctly with streaming_split().
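A minimal sketch of the usage pattern this commit fixes, assuming a local Ray cluster and an illustrative dataset size (not code from this PR):

```python
import ray

# Materialize once, then read the result repeatedly through streaming_split().
# After this fix, the repeated reads should reuse the materialized blocks
# instead of re-executing the upstream stages.
ds = ray.data.range(100).map_batches(lambda batch: batch)
mds = ds.materialize()

for _ in range(3):
    (shard,) = mds.streaming_split(1)
    for batch in shard.iter_batches():
        pass
```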
ericl authored Jun 8, 2023
1 parent fd70407 commit ceef7fd
Showing 8 changed files with 48 additions and 16 deletions.
9 changes: 8 additions & 1 deletion python/ray/data/_internal/execution/interfaces.py
@@ -36,7 +36,7 @@ class RefBundle:
"""

# The size_bytes must be known in the metadata, num_rows is optional.
blocks: List[Tuple[ObjectRef[Block], BlockMetadata]]
blocks: Tuple[Tuple[ObjectRef[Block], BlockMetadata]]

# Whether we own the blocks (can safely destroy them).
owns_blocks: bool
@@ -49,6 +49,8 @@ class RefBundle:
_cached_location: Optional[NodeIdStr] = None

def __post_init__(self):
if not isinstance(self.blocks, tuple):
object.__setattr__(self, "blocks", tuple(self.blocks))
for b in self.blocks:
assert isinstance(b, tuple), b
assert len(b) == 2, b
@@ -59,6 +61,11 @@ def __post_init__(self):
"The size in bytes of the block must be known: {}".format(b)
)

def __setattr__(self, key, value):
if hasattr(self, key) and key in ["blocks", "owns_blocks"]:
raise ValueError(f"The `{key}` field of RefBundle cannot be updated.")
object.__setattr__(self, key, value)

def num_rows(self) -> Optional[int]:
"""Number of rows present in this bundle, if known."""
total = 0
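The interfaces.py change above makes `RefBundle.blocks` a tuple and rejects later reassignment of `blocks` and `owns_blocks`, so downstream operators can no longer mutate a shared bundle in place. A standalone sketch of that pattern with a hypothetical `FrozenBundle` (not Ray's actual class):

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class FrozenBundle:
    # Hypothetical stand-in for RefBundle; block refs are plain ints here.
    blocks: Tuple[int, ...]
    owns_blocks: bool

    def __post_init__(self):
        # Coerce list inputs to a tuple; object.__setattr__ bypasses the guard below.
        if not isinstance(self.blocks, tuple):
            object.__setattr__(self, "blocks", tuple(self.blocks))

    def __setattr__(self, key, value):
        # Allow the first assignment (during __init__), reject later updates.
        if hasattr(self, key) and key in ("blocks", "owns_blocks"):
            raise ValueError(f"The `{key}` field cannot be updated.")
        object.__setattr__(self, key, value)


b = FrozenBundle(blocks=[1, 2, 3], owns_blocks=True)
assert b.blocks == (1, 2, 3)  # the list input was coerced to a tuple
try:
    b.blocks = ()  # later reassignment is rejected
except ValueError:
    pass
```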
12 changes: 9 additions & 3 deletions python/ray/data/_internal/execution/operators/map_operator.py
@@ -1,7 +1,7 @@
import copy
import itertools
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

import ray
@@ -523,12 +523,15 @@ def get_next(self) -> RefBundle:
out_bundle = self._tasks_by_output_order[self._next_output_index].output
# Pop out the next single-block bundle.
next_bundle = RefBundle(
[out_bundle.blocks.pop(0)], owns_blocks=out_bundle.owns_blocks
[out_bundle.blocks[0]], owns_blocks=out_bundle.owns_blocks
)
out_bundle = replace(out_bundle, blocks=out_bundle.blocks[1:])
if not out_bundle.blocks:
# If this task's RefBundle is exhausted, move to the next one.
del self._tasks_by_output_order[self._next_output_index]
self._next_output_index += 1
else:
self._tasks_by_output_order[self._next_output_index].output = out_bundle
return next_bundle


@@ -549,11 +552,14 @@ def get_next(self) -> RefBundle:
out_bundle = self._completed_tasks[0].output
# Pop out the next single-block bundle.
next_bundle = RefBundle(
[out_bundle.blocks.pop(0)], owns_blocks=out_bundle.owns_blocks
[out_bundle.blocks[0]], owns_blocks=out_bundle.owns_blocks
)
out_bundle = replace(out_bundle, blocks=out_bundle.blocks[1:])
if not out_bundle.blocks:
# If this task's RefBundle is exhausted, move to the next one.
del self._completed_tasks[0]
else:
self._completed_tasks[0].output = out_bundle
return next_bundle


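Both `get_next()` hunks above replace the mutating `blocks.pop(0)` with slicing plus `dataclasses.replace`, producing a fresh bundle for the remainder instead of editing the original. A minimal sketch of the same pattern, using a hypothetical `Bundle` class:

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class Bundle:
    # Hypothetical stand-in for RefBundle.
    blocks: Tuple[str, ...]
    owns_blocks: bool


def pop_first_block(bundle: Bundle) -> Tuple[Bundle, Bundle]:
    """Return a single-block bundle plus a new bundle holding the remainder."""
    head = Bundle(blocks=bundle.blocks[:1], owns_blocks=bundle.owns_blocks)
    rest = replace(bundle, blocks=bundle.blocks[1:])
    return head, rest


head, rest = pop_first_block(Bundle(blocks=("b0", "b1", "b2"), owns_blocks=True))
assert head.blocks == ("b0",) and rest.blocks == ("b1", "b2")
# The caller keeps `rest` as the task's remaining output only if it still has blocks.
```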
6 changes: 4 additions & 2 deletions python/ray/data/_internal/iterator/stream_split_iterator.py
@@ -2,6 +2,7 @@
import logging
import threading
import time
from dataclasses import replace
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union

import ray
@@ -195,15 +196,16 @@ def get(
# This is a BLOCKING call, so do it outside the lock.
next_bundle = self._output_iterator.get_next(output_split_idx)

bundle = next_bundle.blocks.pop()
block = next_bundle.blocks[-1]
next_bundle = replace(next_bundle, blocks=next_bundle.blocks[:-1])

# Accumulate any remaining blocks in next_bundle map as needed.
with self._lock:
self._next_bundle[output_split_idx] = next_bundle
if not next_bundle.blocks:
del self._next_bundle[output_split_idx]

return bundle
return block
except StopIteration:
return None

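In stream_split_iterator.py, a multi-block bundle is now consumed one block per call: the last block is handed out and the trimmed bundle is parked per output split. A rough sketch of that bookkeeping with a hypothetical `ToyCoordinator` (not Ray's actual coordinator actor):

```python
import threading
from typing import Dict, Iterator, List, Optional


class ToyCoordinator:
    """Hypothetical per-split block handout mirroring the bookkeeping above."""

    def __init__(self, bundles_per_split: Dict[int, Iterator[List[str]]]):
        self._iters = bundles_per_split  # each iterator yields a bundle's block list
        self._leftover: Dict[int, List[str]] = {}
        self._lock = threading.Lock()

    def get(self, split_idx: int) -> Optional[str]:
        with self._lock:
            blocks = self._leftover.pop(split_idx, None)
        if not blocks:
            try:
                # The potentially blocking fetch of the next bundle stays outside the lock.
                blocks = list(next(self._iters[split_idx]))
            except StopIteration:
                return None
        block, rest = blocks[-1], blocks[:-1]  # hand out one block
        with self._lock:
            if rest:
                self._leftover[split_idx] = rest  # park the remainder for the next call
        return block


coord = ToyCoordinator({0: iter([["b0", "b1"], ["b2"]])})
assert [coord.get(0) for _ in range(4)] == ["b1", "b0", "b2", None]
```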
20 changes: 14 additions & 6 deletions python/ray/data/dataset.py
@@ -3988,12 +3988,20 @@ def materialize(self) -> "MaterializedDataset":
)
for block_with_metadata in blocks_with_metadata
]

# Create a new logical plan whose input is the existing data
# from the the old Dataset.
copy._logical_plan = LogicalPlan(InputData(input_data=ref_bundles))

return copy
logical_plan = LogicalPlan(InputData(input_data=ref_bundles))
output = MaterializedDataset(
ExecutionPlan(
blocks,
copy._plan.stats(),
run_by_consumer=False,
),
copy._epoch,
copy._lazy,
logical_plan,
)
output._plan.execute() # No-op that marks the plan as fully executed.
output._plan._in_stats.dataset_uuid = self._get_uuid()
return output

@ConsumptionAPI(pattern="timing information.", insert_after=True)
def stats(self) -> str:
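The `materialize()` rewrite above stops patching the logical plan of a Dataset copy and instead returns a fresh `MaterializedDataset` whose plan is already marked executed over the snapshotted blocks, so every later consumer (including `streaming_split()`) starts from frozen data. A rough non-Ray sketch of that design with a hypothetical `TinyDataset`:

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class TinyDataset:
    """Hypothetical mini-pipeline: materialize() freezes results as a new source."""

    source: Callable[[], List[int]]
    stages: List[Callable[[List[int]], List[int]]] = field(default_factory=list)

    def map(self, fn: Callable[[List[int]], List[int]]) -> "TinyDataset":
        return TinyDataset(self.source, self.stages + [fn])

    def execute(self) -> List[int]:
        blocks = self.source()
        for stage in self.stages:
            blocks = stage(blocks)
        return blocks

    def materialize(self) -> "TinyDataset":
        blocks = self.execute()  # run the pipeline exactly once
        return TinyDataset(lambda: list(blocks), [])  # new plan: just the frozen blocks


calls: List[int] = []


def double(blocks: List[int]) -> List[int]:
    calls.append(1)
    return [b * 2 for b in blocks]


mds = TinyDataset(lambda: [1, 2, 3]).map(double).materialize()
assert mds.execute() == [2, 4, 6] and mds.execute() == [2, 4, 6]
assert len(calls) == 1  # the map stage never re-ran after materialize()
```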
7 changes: 7 additions & 0 deletions python/ray/data/tests/test_consumption.py
@@ -199,11 +199,18 @@ def inc(x):
assert isinstance(ds2, MaterializedDataset)
assert not ds.is_fully_executed()

# Tests standard iteration uses the materialized blocks.
for _ in range(10):
ds2.take_all()

assert ray.get(c.inc.remote()) == 2

# Tests streaming iteration uses the materialized blocks.
for _ in range(10):
list(ds2.streaming_split(1)[0].iter_batches())

assert ray.get(c.inc.remote()) == 3


def test_schema(ray_start_regular_shared):
ds2 = ray.data.range(10, parallelism=10)
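The new test reads the materialized dataset repeatedly, via both `take_all()` and `streaming_split()`, and probes a counter to confirm the upstream stages didn't re-run. A hypothetical counter-actor sketch in the same spirit (the test's own `inc`/counter helpers are defined outside the shown hunk):

```python
import ray


@ray.remote
class ExecutionCounter:
    # Hypothetical actor; the real test uses its own counter helpers.
    def __init__(self):
        self.value = 0

    def inc(self):
        self.value += 1
        return self.value


counter = ExecutionCounter.remote()


def bump(batch):
    ray.get(counter.inc.remote())  # counts how many times the stage actually runs
    return batch


mds = ray.data.range(10).map_batches(bump).materialize()
# Repeated reads of `mds` (take_all, streaming_split, ...) should not grow the
# counter further, because the materialized blocks are reused.
```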
5 changes: 2 additions & 3 deletions python/ray/data/tests/test_optimize.py
@@ -51,9 +51,7 @@ def expect_stages(pipe, num_stages_expected, stage_names):
name = " " + name + ":"
assert name in stats, (name, stats)
if isinstance(pipe, Dataset):
assert (
len(pipe._plan._stages_before_snapshot) == num_stages_expected
), pipe._plan._stages_before_snapshot
pass
else:
assert (
len(pipe._optimized_stages) == num_stages_expected
@@ -322,6 +320,7 @@ def test_optimize_reorder(ray_start_regular_shared):
context.optimize_reorder_stages = True

ds = ray.data.range(10).randomize_block_order().map_batches(dummy_map).materialize()
print("Stats", ds.stats())
expect_stages(
ds,
2,
1 change: 1 addition & 0 deletions python/ray/data/tests/test_stats.py
@@ -851,6 +851,7 @@ def test_dataset_pipeline_cache_cases(ray_start_regular_shared):
ds = ray.data.range(10).materialize().repeat(2).map_batches(lambda x: x)
ds.take(999)
stats = ds.stats()
print("STATS", stats)
assert "[execution cached]" in stats

# CACHED (eager map stage).
4 changes: 3 additions & 1 deletion python/ray/data/tests/test_torch.py
@@ -318,8 +318,10 @@ def test_torch_trainer_crash(ray_start_10_cpus_shared):
def train_loop_per_worker():
it = session.get_dataset_shard("train")
for i in range(2):
count = 0
for batch in it.iter_batches():
pass
count += len(batch["data"])
assert count == 50

my_trainer = TorchTrainer(
train_loop_per_worker,
