Skip to content

Commit

Permalink
Fix iter_batches dropping batches when prefetching.
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed Sep 8, 2021
1 parent 6011d41 commit 4ea72cc
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
18 changes: 14 additions & 4 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,14 +1023,24 @@ def format_batch(batch: Block, format: str) -> BatchType:
f"is invalid. Supported batch type: {BatchType}")

batcher = Batcher(batch_size=batch_size)
for block_window in sliding_window(self._blocks, prefetch_blocks + 1):
block_window = list(block_window)
ray.wait(block_window, num_returns=1, fetch_local=True)
block = ray.get(block_window[0])

def batch_block(block: ObjectRef[Block]):
    """Resolve one block reference and yield the batches it completes.

    The reference is fetched with ``ray.get`` and the result is added to
    the ``batcher`` from the enclosing scope. Batches are then yielded,
    each passed through ``format_batch`` with the enclosing
    ``batch_format``, for as long as the batcher reports one ready.
    Whatever the batcher still holds afterwards is not drained here.

    Args:
        block: Object reference to the block to fetch and batch.

    Yields:
        The value returned by ``format_batch`` for each ready batch.
    """
    batcher.add(ray.get(block))
    while True:
        # Stop as soon as the batcher has no complete batch to hand out.
        if not batcher.has_batch():
            return
        yield format_batch(batcher.next_batch(), batch_format)

block_window = [] # Handle empty sliding window gracefully.
for block_window in sliding_window(self._blocks, prefetch_blocks + 1):
block_window = list(block_window)
ray.wait(block_window, num_returns=1, fetch_local=True)
yield from batch_block(block_window[0])

# Consume remainder of final block window.
for block in block_window[1:]:
yield from batch_block(block)

# Yield any remainder batches.
if batcher.has_any() and not drop_last:
yield format_batch(batcher.next_batch(), batch_format)

Expand Down
10 changes: 10 additions & 0 deletions python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,16 @@ def test_iter_batches_basic(ray_start_regular_shared):
assert isinstance(batch, pd.DataFrame)
assert batch.equals(df)

batch_size = 2
batches = list(
ds.iter_batches(
prefetch_blocks=2, batch_size=batch_size, batch_format="pandas"))
assert all(len(batch) == batch_size for batch in batches)
assert (len(batches) == math.ceil(
(len(df1) + len(df2) + len(df3) + len(df4)) / batch_size))
assert pd.concat(
batches, ignore_index=True).equals(pd.concat(dfs, ignore_index=True))


def test_iter_batches_grid(ray_start_regular_shared):
# Tests slicing, batch combining, and partial batch dropping logic over
Expand Down

0 comments on commit 4ea72cc

Please sign in to comment.