diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index e74b9060943a7..85c4a31e90f07 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1023,14 +1023,24 @@ def format_batch(batch: Block, format: str) -> BatchType: f"is invalid. Supported batch type: {BatchType}") batcher = Batcher(batch_size=batch_size) - for block_window in sliding_window(self._blocks, prefetch_blocks + 1): - block_window = list(block_window) - ray.wait(block_window, num_returns=1, fetch_local=True) - block = ray.get(block_window[0]) + + def batch_block(block: ObjectRef[Block]): + block = ray.get(block) batcher.add(block) while batcher.has_batch(): yield format_batch(batcher.next_batch(), batch_format) + block_window = [] # Handle empty sliding window gracefully. + for block_window in sliding_window(self._blocks, prefetch_blocks + 1): + block_window = list(block_window) + ray.wait(block_window, num_returns=1, fetch_local=True) + yield from batch_block(block_window[0]) + + # Consume remainder of final block window. + for block in block_window[1:]: + yield from batch_block(block) + + # Yield any remainder batches. if batcher.has_any() and not drop_last: yield format_batch(batcher.next_batch(), batch_format) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index 7c984270125b8..633e3b84e45ab 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -1325,6 +1325,16 @@ def test_iter_batches_basic(ray_start_regular_shared): assert isinstance(batch, pd.DataFrame) assert batch.equals(df) + batch_size = 2 + batches = list( + ds.iter_batches( + prefetch_blocks=2, batch_size=batch_size, batch_format="pandas")) + assert all(len(batch) == batch_size for batch in batches) + assert (len(batches) == math.ceil( + (len(df1) + len(df2) + len(df3) + len(df4)) / batch_size)) + assert pd.concat( + batches, ignore_index=True).equals(pd.concat(dfs, ignore_index=True)) + def test_iter_batches_grid(ray_start_regular_shared): # Tests slicing, batch combining, and partial batch dropping logic over