32
32
import tensorflow as tf
33
33
import torch
34
34
35
- from ray .data ._internal .torch_iterable_dataset import TorchTensorBatchType
36
- from ray .data .dataset import Schema , TensorFlowTensorBatchType
35
+ from ray .data .dataset import (
36
+ CollatedData ,
37
+ Schema ,
38
+ TensorFlowTensorBatchType ,
39
+ TorchBatchType ,
40
+ )
37
41
38
42
39
43
@PublicAPI (stability = "beta" )
@@ -93,7 +97,7 @@ def iter_batches(
93
97
drop_last : bool = False ,
94
98
local_shuffle_buffer_size : Optional [int ] = None ,
95
99
local_shuffle_seed : Optional [int ] = None ,
96
- _collate_fn : Optional [Callable [[DataBatch ], Any ]] = None ,
100
+ _collate_fn : Optional [Callable [[DataBatch ], "CollatedData" ]] = None ,
97
101
_finalize_fn : Optional [Callable [[Any ], Any ]] = None ,
98
102
# Deprecated.
99
103
prefetch_blocks : int = 0 ,
@@ -255,29 +259,48 @@ def iter_torch_batches(
255
259
prefetch_batches : int = 1 ,
256
260
batch_size : Optional [int ] = 256 ,
257
261
dtypes : Optional [Union ["torch.dtype" , Dict [str , "torch.dtype" ]]] = None ,
258
- device : Optional [str ] = None ,
259
- collate_fn : Optional [
260
- Callable [[Union [np .ndarray , Dict [str , np .ndarray ]]], Any ]
261
- ] = None ,
262
+ device : str = "auto" ,
263
+ collate_fn : Optional [Callable [[Dict [str , np .ndarray ]], "CollatedData" ]] = None ,
262
264
drop_last : bool = False ,
263
265
local_shuffle_buffer_size : Optional [int ] = None ,
264
266
local_shuffle_seed : Optional [int ] = None ,
265
267
# Deprecated.
266
268
prefetch_blocks : int = 0 ,
267
- ) -> Iterator ["TorchTensorBatchType " ]:
269
+ ) -> Iterator ["TorchBatchType " ]:
268
270
"""Return a batched iterator of Torch Tensors over the dataset.
269
271
270
- This iterator will yield single-tensor batches if the underlying dataset
271
- consists of a single column; otherwise, it will yield a dictionary of
272
- column-tensors. If looking for more flexibility in the tensor conversion (e.g.
273
- casting dtypes) or the batch format, try using `.iter_batches` directly.
272
+ This iterator yields a dictionary of column-tensors. If you are looking for
273
+ more flexibility in the tensor conversion (e.g. casting dtypes) or the batch
274
+ format, try using :meth:`~ray.data.iterator.DataIterator.iter_batches` directly.
274
275
275
276
Examples:
276
277
>>> import ray
277
- >>> for row in ray.data.range(
278
- ... 1000000
279
- ... ).iterator().iter_rows(): # doctest: +SKIP
280
- ... print(row) # doctest: +SKIP
278
+ >>> for batch in ray.data.range(
279
+ ... 12,
280
+ ... ).iterator().iter_torch_batches(batch_size=4):
281
+ ... print(batch)
282
+ {'id': tensor([0, 1, 2, 3])}
283
+ {'id': tensor([4, 5, 6, 7])}
284
+ {'id': tensor([ 8, 9, 10, 11])}
285
+
286
+ Use the ``collate_fn`` to customize how the tensor batch is created.
287
+
288
+ >>> from typing import Any, Dict
289
+ >>> import torch
290
+ >>> import numpy as np
291
+ >>> import ray
292
+ >>> def collate_fn(batch: Dict[str, np.ndarray]) -> Any:
293
+ ... return torch.stack(
294
+ ... [torch.as_tensor(array) for array in batch.values()],
295
+ ... axis=1
296
+ ... )
297
+ >>> iterator = ray.data.from_items([
298
+ ... {"col_1": 1, "col_2": 2},
299
+ ... {"col_1": 3, "col_2": 4}]).iterator()
300
+ >>> for batch in iterator.iter_torch_batches(collate_fn=collate_fn):
301
+ ... print(batch)
302
+ tensor([[1, 2],
303
+ [3, 4]])
281
304
282
305
Time complexity: O(1)
283
306
@@ -293,17 +316,24 @@ def iter_torch_batches(
293
316
The final batch may include fewer than ``batch_size`` rows if
294
317
``drop_last`` is ``False``. Defaults to 256.
295
318
dtypes: The Torch dtype(s) for the created tensor(s); if None, the dtype
296
- will be inferred from the tensor data.
297
- device: The device on which the tensor should be placed; if None, the Torch
298
- tensor will be constructed on the CPU.
299
- collate_fn: A function to apply to each data batch before returning it. When
300
- this parameter is specified, the user should manually handle the host
301
- to device data transfer outside of collate_fn. Potential use cases
302
- include collating along a dimension other than the first, padding
303
- sequences of various lengths, or generally handling batches of different
304
- length tensors. This API is still experimental and is subject to change.
305
- This parameter cannot be used in conjunction with ``dtypes`` or
306
- ``device``.
319
+ will be inferred from the tensor data. You can't use this parameter
320
+ with ``collate_fn``.
321
+ device: The device on which the tensor should be placed. Defaults to
322
+ "auto" which moves the tensors to the appropriate device when the
323
+ Dataset is passed to Ray Train and ``collate_fn`` is not provided.
324
+ Otherwise, defaults to CPU. You can't use this parameter with
325
+ ``collate_fn``.
326
+ collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
327
+ When this parameter is specified, the user should manually handle the
328
+ host to device data transfer outside of ``collate_fn``.
329
+ This is useful for further processing the data after it has been
330
+ batched. Potential use cases include collating along a dimension other
331
+ than the first, padding sequences of various lengths, or generally
332
+ handling batches of different length tensors. If not provided, the
333
+ default collate function is used which simply converts the batch of
334
+ numpy arrays to a batch of PyTorch tensors. This API is still
335
+ experimental and is subject to change. You can't use this parameter in
336
+ conjunction with ``dtypes`` or ``device``.
307
337
drop_last: Whether to drop the last batch if it's incomplete.
308
338
local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
309
339
using a local in-memory shuffle buffer, and this value will serve as the
@@ -324,20 +354,19 @@ def iter_torch_batches(
324
354
get_device ,
325
355
)
326
356
327
- if collate_fn is not None and (dtypes is not None or device is not None ):
357
+ if collate_fn is not None and (dtypes is not None or device != "auto" ):
328
358
raise ValueError (
329
359
"collate_fn cannot be used with dtypes and device."
330
360
"You should manually move the output Torch tensors to the"
331
361
"desired dtype and device outside of collate_fn."
332
362
)
333
363
334
- if collate_fn is None :
335
- # Automatically move torch tensors to the appropriate device.
336
- if device is None :
337
- default_device = get_device ()
338
- if default_device .type != "cpu" :
339
- device = default_device
364
+ if device == "auto" :
365
+ # Use the appropriate device for Ray Train, or falls back to CPU if
366
+ # Ray Train is not being used.
367
+ device = get_device ()
340
368
369
+ if collate_fn is None :
341
370
# The default collate_fn handles formatting and Tensor creation.
342
371
# Here, we set device=None to defer host to device data transfer
343
372
# to the subsequent finalize_fn.
0 commit comments