Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Datasets] Fix iter_batches dropping batches when prefetching. #18441

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,14 +1023,24 @@ def format_batch(batch: Block, format: str) -> BatchType:
f"is invalid. Supported batch type: {BatchType}")

batcher = Batcher(batch_size=batch_size)
for block_window in sliding_window(self._blocks, prefetch_blocks + 1):
block_window = list(block_window)
ray.wait(block_window, num_returns=1, fetch_local=True)
block = ray.get(block_window[0])

def batch_block(block: ObjectRef[Block]):
    """Materialize one block and yield every full batch it completes.

    Resolves the block ObjectRef, feeds the data into the enclosing-scope
    `batcher`, and yields each complete batch formatted via `format_batch`
    with the enclosing-scope `batch_format`.

    NOTE(review): closure over `batcher`, `format_batch`, and
    `batch_format` from the surrounding iter_batches body; any partial
    trailing batch stays buffered in `batcher` for later blocks (or the
    final drain) rather than being yielded here.
    """
    # Fetch the actual block data for this ObjectRef.
    block = ray.get(block)
    batcher.add(block)
    # Drain only *complete* batches; a remainder smaller than batch_size
    # is left in the batcher.
    while batcher.has_batch():
        yield format_batch(batcher.next_batch(), batch_format)

block_window = [] # Handle empty sliding window gracefully.
for block_window in sliding_window(self._blocks, prefetch_blocks + 1):
block_window = list(block_window)
ray.wait(block_window, num_returns=1, fetch_local=True)
yield from batch_block(block_window[0])

# Consume remainder of final block window.
for block in block_window[1:]:
yield from batch_block(block)

# Yield any remainder batches.
if batcher.has_any() and not drop_last:
yield format_batch(batcher.next_batch(), batch_format)

Expand Down
10 changes: 10 additions & 0 deletions python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,16 @@ def test_iter_batches_basic(ray_start_regular_shared):
assert isinstance(batch, pd.DataFrame)
assert batch.equals(df)

batch_size = 2
batches = list(
ds.iter_batches(
prefetch_blocks=2, batch_size=batch_size, batch_format="pandas"))
assert all(len(batch) == batch_size for batch in batches)
assert (len(batches) == math.ceil(
(len(df1) + len(df2) + len(df3) + len(df4)) / batch_size))
assert pd.concat(
batches, ignore_index=True).equals(pd.concat(dfs, ignore_index=True))


def test_iter_batches_grid(ray_start_regular_shared):
# Tests slicing, batch combining, and partial batch dropping logic over
Expand Down