Commit 8e4b6b0

[Data] iter_torch_batches updates (ray-project#37625)
Adds more documentation to the iter_torch_batches docstring. Changes the default value of the device parameter to "auto" to make the behavior of automatic device transfer more explicit.

Signed-off-by: amogkam <amogkamsetty@yahoo.com>
1 parent 5ce25a5 commit 8e4b6b0
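
Before the diffs, a hedged sketch of the user-facing change (API names are taken from the diffs below; the device placement outcomes are illustrative, not captured output):

```python
import ray

ds = ray.data.range(4)

# After this commit, device defaults to "auto": batches move to the Ray Train
# worker's assigned device automatically, and stay on CPU outside Ray Train.
for batch in ds.iter_torch_batches(batch_size=2):
    print(batch["id"].device)  # cpu when no Ray Train session is active

# An explicit device is still accepted, but combining device (or dtypes)
# with collate_fn now raises a ValueError.
for batch in ds.iter_torch_batches(batch_size=2, device="cpu"):
    print(batch["id"].device)  # cpu
```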

File tree

3 files changed: +116 −72 lines

- python/ray/data/_internal/torch_iterable_dataset.py
- python/ray/data/dataset.py
- python/ray/data/iterator.py

python/ray/data/_internal/torch_iterable_dataset.py

−8 lines

```diff
@@ -1,13 +1,5 @@
-from typing import TYPE_CHECKING, Dict, Union
-
 from torch.utils.data import IterableDataset
 
-if TYPE_CHECKING:
-    import torch
-
-
-TorchTensorBatchType = Union["torch.Tensor", Dict[str, "torch.Tensor"]]
-
 
 class TorchIterableDataset(IterableDataset):
     def __init__(self, generator_func):
```

python/ray/data/dataset.py

+53 −30 lines

```diff
@@ -18,6 +18,7 @@
     Optional,
     Tuple,
     Type,
+    TypeVar,
     Union,
 )
 from uuid import uuid4
```
```diff
@@ -152,7 +153,6 @@
     from tensorflow_metadata.proto.v0 import schema_pb2
 
     from ray.data._internal.execution.interfaces import Executor, NodeIdStr
-    from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType
     from ray.data.dataset_pipeline import DatasetPipeline
     from ray.data.grouped_data import GroupedData
```

```diff
@@ -165,6 +165,9 @@
 
 TensorFlowTensorBatchType = Union["tf.Tensor", Dict[str, "tf.Tensor"]]
 
+CollatedData = TypeVar("CollatedData")
+TorchBatchType = Union[Dict[str, "torch.Tensor"], CollatedData]
+
 
 @PublicAPI
 class Dataset:
```
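
A minimal sketch, not part of the commit, of why `CollatedData` is a `TypeVar`: the return type of a user-supplied `collate_fn` flows through to the iterator's element type (simplified here to tensor batches; the real `collate_fn` receives NumPy batches):

```python
from typing import Callable, Dict, Iterator, Optional, TypeVar, Union

import torch

CollatedData = TypeVar("CollatedData")
TorchBatchType = Union[Dict[str, "torch.Tensor"], CollatedData]


def iter_collated(
    batches: Iterator[Dict[str, torch.Tensor]],
    collate_fn: Optional[Callable[[Dict[str, torch.Tensor]], CollatedData]] = None,
) -> Iterator[TorchBatchType]:
    # Without collate_fn, yield the default dict-of-tensors batches; with it,
    # the yielded type is whatever collate_fn returns (CollatedData).
    for batch in batches:
        yield collate_fn(batch) if collate_fn else batch
```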
```diff
@@ -3336,7 +3339,7 @@ def iter_batches(
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
-        _collate_fn: Optional[Callable[[DataBatch], Any]] = None,
+        _collate_fn: Optional[Callable[[DataBatch], CollatedData]] = None,
         # Deprecated.
         prefetch_blocks: int = 0,
     ) -> Iterator[DataBatch]:
```
```diff
@@ -3408,37 +3411,49 @@ def iter_torch_batches(
         prefetch_batches: int = 1,
         batch_size: Optional[int] = 256,
         dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
-        device: Optional[str] = None,
-        collate_fn: Optional[Callable[[Dict[str, np.ndarray]], Any]] = None,
+        device: str = "auto",
+        collate_fn: Optional[Callable[[Dict[str, np.ndarray]], CollatedData]] = None,
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
         # Deprecated
         prefetch_blocks: int = 0,
-    ) -> Iterator["TorchTensorBatchType"]:
+    ) -> Iterator[TorchBatchType]:
         """Return an iterator over batches of data represented as Torch tensors.
 
         This iterator yields batches of type ``Dict[str, torch.Tensor]``.
         For more flexibility, call :meth:`~Dataset.iter_batches` and manually convert
         your data to Torch tensors.
 
         Examples:
+            >>> import ray
+            >>> for batch in ray.data.range(
+            ...     12,
+            ... ).iter_torch_batches(batch_size=4):
+            ...     print(batch)
+            {'id': tensor([0, 1, 2, 3])}
+            {'id': tensor([4, 5, 6, 7])}
+            {'id': tensor([ 8,  9, 10, 11])}
+
+            Use the ``collate_fn`` to customize how the tensor batch is created.
+
+            >>> from typing import Any, Dict
+            >>> import torch
+            >>> import numpy as np
+            >>> import ray
+            >>> def collate_fn(batch: Dict[str, np.ndarray]) -> Any:
+            ...     return torch.stack(
+            ...         [torch.as_tensor(array) for array in batch.values()],
+            ...         axis=1
+            ...     )
+            >>> dataset = ray.data.from_items([
+            ...     {"col_1": 1, "col_2": 2},
+            ...     {"col_1": 3, "col_2": 4}])
+            >>> for batch in dataset.iter_torch_batches(collate_fn=collate_fn):
+            ...     print(batch)
+            tensor([[1, 2],
+                    [3, 4]])
 
-            .. testcode::
-
-                import ray
-
-                # This dataset contains three images.
-                ds = ray.data.read_images("example://image-datasets/simple")
-
-                for batch in ds.iter_torch_batches(batch_size=2):
-                    print(batch)
-
-            .. testoutput::
-                :options: +MOCK
-
-                {'image': <tf.Tensor: shape=(2, 32, 32, 3), dtype=uint8, numpy=array([[[[...]]]], dtype=uint8)>}
-                {'image': <tf.Tensor: shape=(1, 32, 32, 3), dtype=uint8, numpy=array([[[[...]]]], dtype=uint8)>}
 
         Time complexity: O(1)
```
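
A hedged companion to the doctests above, showing the `dtypes` parameter documented in the next hunk (not one of the commit's examples; note it can't be combined with `collate_fn`):

```python
import torch

import ray

ds = ray.data.range(4)

# A single dtype applies to every column; pass a Dict[str, torch.dtype]
# to cast columns individually instead.
for batch in ds.iter_torch_batches(batch_size=2, dtypes=torch.float32):
    print(batch["id"].dtype)  # torch.float32
```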
```diff
@@ -3451,17 +3466,25 @@ def iter_torch_batches(
                 blocks as batches (blocks may contain different number of rows).
                 The final batch may include fewer than ``batch_size`` rows if
                 ``drop_last`` is ``False``. Defaults to 256.
-            dtypes: The Torch dtype(s) for the created tensor(s); if ``None``, the
-                dtype is inferred from the tensor data.
-            device: The device on which the tensor should be placed; if ``None``, the
-                Torch tensor is constructed on CPU.
+            dtypes: The Torch dtype(s) for the created tensor(s); if ``None``, the dtype
+                is inferred from the tensor data. You can't use this parameter with
+                ``collate_fn``.
+            device: The device on which the tensor should be placed. Defaults to
+                "auto" which moves the tensors to the appropriate device when the
+                Dataset is passed to Ray Train and ``collate_fn`` is not provided.
+                Otherwise, defaults to CPU. You can't use this parameter with
+                ``collate_fn``.
             collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
-                Potential use cases include collating along a dimension other than the
-                first, padding sequences of various lengths, or generally handling
-                batches of different length tensors. If not provided, the default
-                collate function is used which simply converts the batch of numpy
-                arrays to a batch of PyTorch tensors. This API is still experimental
-                and is subject to change.
+                When this parameter is specified, the user should manually handle the
+                host to device data transfer outside of collate_fn.
+                This is useful for further processing the data after it has been
+                batched. Potential use cases include collating along a dimension other
+                than the first, padding sequences of various lengths, or generally
+                handling batches of different length tensors. If not provided, the
+                default collate function is used which simply converts the batch of
+                numpy arrays to a batch of PyTorch tensors. This API is still
+                experimental and is subject to change. You can't use this parameter in
+                conjunction with ``dtypes`` or ``device``.
             drop_last: Whether to drop the last batch if it's incomplete.
             local_shuffle_buffer_size: If not ``None``, the data is randomly shuffled
                 using a local in-memory shuffle buffer, and this value serves as the
```
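
To ground the padding use case named in the `collate_fn` docs above, a hedged sketch, not from the commit; it assumes ragged rows surface to `collate_fn` as an object-dtype NumPy array under a hypothetical "tokens" column:

```python
from typing import Dict

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

import ray


def pad_collate_fn(batch: Dict[str, np.ndarray]) -> torch.Tensor:
    # Pad variable-length rows up to the longest row in this batch.
    rows = [torch.as_tensor(row) for row in batch["tokens"]]  # hypothetical column
    padded = pad_sequence(rows, batch_first=True, padding_value=0)
    # With collate_fn supplied, host-to-device transfer is the caller's job.
    return padded.to("cuda" if torch.cuda.is_available() else "cpu")


ds = ray.data.from_items(
    [{"tokens": np.array([1, 2, 3])}, {"tokens": np.array([4, 5])}]
)
for batch in ds.iter_torch_batches(batch_size=2, collate_fn=pad_collate_fn):
    print(batch.shape)  # torch.Size([2, 3])
```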

python/ray/data/iterator.py

+63 −34 lines

```diff
@@ -32,8 +32,12 @@
     import tensorflow as tf
     import torch
 
-    from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType
-    from ray.data.dataset import Schema, TensorFlowTensorBatchType
+    from ray.data.dataset import (
+        CollatedData,
+        Schema,
+        TensorFlowTensorBatchType,
+        TorchBatchType,
+    )
 
 
 @PublicAPI(stability="beta")
```
```diff
@@ -93,7 +97,7 @@ def iter_batches(
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
-        _collate_fn: Optional[Callable[[DataBatch], Any]] = None,
+        _collate_fn: Optional[Callable[[DataBatch], "CollatedData"]] = None,
         _finalize_fn: Optional[Callable[[Any], Any]] = None,
         # Deprecated.
         prefetch_blocks: int = 0,
```
```diff
@@ -255,29 +259,48 @@ def iter_torch_batches(
         prefetch_batches: int = 1,
         batch_size: Optional[int] = 256,
         dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
-        device: Optional[str] = None,
-        collate_fn: Optional[
-            Callable[[Union[np.ndarray, Dict[str, np.ndarray]]], Any]
-        ] = None,
+        device: str = "auto",
+        collate_fn: Optional[Callable[[Dict[str, np.ndarray]], "CollatedData"]] = None,
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
         # Deprecated.
         prefetch_blocks: int = 0,
-    ) -> Iterator["TorchTensorBatchType"]:
+    ) -> Iterator["TorchBatchType"]:
         """Return a batched iterator of Torch Tensors over the dataset.
 
-        This iterator will yield single-tensor batches if the underlying dataset
-        consists of a single column; otherwise, it will yield a dictionary of
-        column-tensors. If looking for more flexibility in the tensor conversion (e.g.
-        casting dtypes) or the batch format, try using `.iter_batches` directly.
+        This iterator yields a dictionary of column-tensors. If you are looking for
+        more flexibility in the tensor conversion (e.g. casting dtypes) or the batch
+        format, try using :meth:`~ray.data.iterator.DataIterator.iter_batches` directly.
 
         Examples:
             >>> import ray
-            >>> for row in ray.data.range(
-            ...     1000000
-            ... ).iterator().iter_rows(): # doctest: +SKIP
-            ...     print(row) # doctest: +SKIP
+            >>> for batch in ray.data.range(
+            ...     12,
+            ... ).iterator().iter_torch_batches(batch_size=4):
+            ...     print(batch)
+            {'id': tensor([0, 1, 2, 3])}
+            {'id': tensor([4, 5, 6, 7])}
+            {'id': tensor([ 8,  9, 10, 11])}
+
+            Use the ``collate_fn`` to customize how the tensor batch is created.
+
+            >>> from typing import Any, Dict
+            >>> import torch
+            >>> import numpy as np
+            >>> import ray
+            >>> def collate_fn(batch: Dict[str, np.ndarray]) -> Any:
+            ...     return torch.stack(
+            ...         [torch.as_tensor(array) for array in batch.values()],
+            ...         axis=1
+            ...     )
+            >>> iterator = ray.data.from_items([
+            ...     {"col_1": 1, "col_2": 2},
+            ...     {"col_1": 3, "col_2": 4}]).iterator()
+            >>> for batch in iterator.iter_torch_batches(collate_fn=collate_fn):
+            ...     print(batch)
+            tensor([[1, 2],
+                    [3, 4]])
 
         Time complexity: O(1)
```
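
A hedged sketch of the guard these docs describe; the ValueError is raised by the validation added at the bottom of this file (message text per the diff):

```python
import torch

import ray

it = ray.data.range(4).iterator()

try:
    # Combining collate_fn with an explicit device (or dtypes) now fails fast.
    for _ in it.iter_torch_batches(
        collate_fn=lambda b: torch.as_tensor(b["id"]),
        device="cpu",
    ):
        pass
except ValueError as err:
    print(err)
```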
```diff
@@ -293,17 +316,24 @@ def iter_torch_batches(
                 The final batch may include fewer than ``batch_size`` rows if
                 ``drop_last`` is ``False``. Defaults to 256.
             dtypes: The Torch dtype(s) for the created tensor(s); if None, the dtype
-                will be inferred from the tensor data.
-            device: The device on which the tensor should be placed; if None, the Torch
-                tensor will be constructed on the CPU.
-            collate_fn: A function to apply to each data batch before returning it. When
-                this parameter is specified, the user should manually handle the host
-                to device data transfer outside of collate_fn. Potential use cases
-                include collating along a dimension other than the first, padding
-                sequences of various lengths, or generally handling batches of different
-                length tensors. This API is still experimental and is subject to change.
-                This parameter cannot be used in conjunction with ``dtypes`` or
-                ``device``.
+                will be inferred from the tensor data. You can't use this parameter
+                with ``collate_fn``.
+            device: The device on which the tensor should be placed. Defaults to
+                "auto" which moves the tensors to the appropriate device when the
+                Dataset is passed to Ray Train and ``collate_fn`` is not provided.
+                Otherwise, defaults to CPU. You can't use this parameter with
+                ``collate_fn``.
+            collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
+                When this parameter is specified, the user should manually handle the
+                host to device data transfer outside of ``collate_fn``.
+                This is useful for further processing the data after it has been
+                batched. Potential use cases include collating along a dimension other
+                than the first, padding sequences of various lengths, or generally
+                handling batches of different length tensors. If not provided, the
+                default collate function is used which simply converts the batch of
+                numpy arrays to a batch of PyTorch tensors. This API is still
+                experimental and is subject to change. You can't use this parameter in
+                conjunction with ``dtypes`` or ``device``.
             drop_last: Whether to drop the last batch if it's incomplete.
             local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
                 using a local in-memory shuffle buffer, and this value will serve as the
```
```diff
@@ -324,20 +354,19 @@ def iter_torch_batches(
             get_device,
         )
 
-        if collate_fn is not None and (dtypes is not None or device is not None):
+        if collate_fn is not None and (dtypes is not None or device != "auto"):
             raise ValueError(
                 "collate_fn cannot be used with dtypes and device."
                 "You should manually move the output Torch tensors to the"
                 "desired dtype and device outside of collate_fn."
             )
 
-        if collate_fn is None:
-            # Automatically move torch tensors to the appropriate device.
-            if device is None:
-                default_device = get_device()
-                if default_device.type != "cpu":
-                    device = default_device
+        if device == "auto":
+            # Use the appropriate device for Ray Train, or falls back to CPU if
+            # Ray Train is not being used.
+            device = get_device()
 
+        if collate_fn is None:
             # The default collate_fn handles formatting and Tensor creation.
             # Here, we set device=None to defer host to device data transfer
             # to the subsequent finalize_fn.
```
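
A minimal sketch, not from the commit, of how the new "auto" branch resolves the device; `get_device` stands in for the Ray Train helper imported above:

```python
import torch


def resolve_device(device: str, get_device) -> torch.device:
    # "auto" defers to Ray Train's per-worker device assignment; get_device()
    # returns the CPU device when no Train session is active. Any other value
    # is treated as an explicit device string.
    if device == "auto":
        return get_device()
    return torch.device(device)


print(resolve_device("auto", lambda: torch.device("cpu")))    # cpu
print(resolve_device("cuda:0", lambda: torch.device("cpu")))  # cuda:0
```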
