[AIR/Train] Make Dataset ingest configurable #24066
```diff
@@ -14,7 +14,7 @@
 from ray.train import BackendConfig, TrainingIterator
 from ray.train.backend import BackendExecutor
 from ray.train.checkpoint import TuneCheckpointManager
-from ray.train.utils import construct_train_func
+from ray.train.utils import construct_train_func, RayDatasetSpec

 from ray.util.annotations import DeveloperAPI

 logger = logging.getLogger(__name__)
```
```diff
@@ -292,27 +292,36 @@ def training_loop(self) -> None:
         else:
             resume_checkpoint_dict = None

-        # Tell Ray Train to only shard the train dataset and not the other datasets.
-        # This is purely an implementation detail and users do not need to know about
-        # this.
-        # TODO(amog): Refactor this to remove hack and make this more modular.
-        # TrainingIterator should accept a generic custom_ingest_func that contains
-        # the logic for how to split the Datasets.
-        updated_dataset_dict = {}
-        for key, value in self.datasets.items():
-            if key == TRAIN_DATASET_KEY:
-                updated_dataset_dict[key] = value
-            else:
-                # Ray Train will strip out the added string before exposing to users.
-                updated_dataset_dict[key + "_NO-SHARD"] = value
+        def dataset_split_fn(dataset_dict, training_worker_handles):
+            dataset_dict_splits = [{} for _ in range(len(training_worker_handles))]
+
+            for key, dataset in dataset_dict.items():
+                if key == TRAIN_DATASET_KEY:
+                    dataset_splits = dataset.split(
+                        len(training_worker_handles),
+                        equal=True,
+                        locality_hints=training_worker_handles,
+                    )
+                else:
+                    # Only shard the training dataset.
+                    dataset_splits = [dataset] * len(training_worker_handles)
+
+                for i in range(len(dataset_splits)):
+                    dataset_dict_splits[i][key] = dataset_splits[i]
+
+            return dataset_dict_splits
+
+        dataset_spec = RayDatasetSpec(
+            dataset_or_dict=self.datasets, dataset_split_fn=dataset_split_fn
+        )
+
         # TODO(amog): Have TrainingIterator also accept a checkpoint ObjectRef instead
         # of just a Dict.
         training_iterator = TrainingIterator(
             backend_executor=backend_executor,
             backend_config=self.backend_config,
             train_func=train_loop_per_worker,
-            dataset=updated_dataset_dict if len(updated_dataset_dict) > 0 else None,
+            dataset_spec=dataset_spec,
             checkpoint_manager=checkpoint_manager,
             checkpoint=resume_checkpoint_dict,
             checkpoint_strategy=None,
```

Review comment (on the `dataset_split_fn` definition): Could we add a type annotation to [...]?
Reply: Added!

Review comment (on the unsharded branch): Could you add a comment explaining why we're only sharding the training dataset?
Reply: Added!
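To make the contract of the new `dataset_split_fn` concrete, here is a minimal, self-contained sketch of the same "shard only the train dataset" pattern outside the trainer. It is an illustration rather than code from this PR: the `Worker` actor, the `only_shard_train` name, and the toy `range` datasets are stand-ins, and it assumes a Ray version where `Dataset.split(..., locality_hints=...)` accepts actor handles, as used in the diff above.

```python
from typing import Dict, List

import ray
from ray.actor import ActorHandle
from ray.data import Dataset


@ray.remote
class Worker:
    """Stand-in for the training worker actors that Ray Train manages."""


def only_shard_train(
    dataset_dict: Dict[str, Dataset], worker_handles: List[ActorHandle]
) -> List[Dict[str, Dataset]]:
    """Split the "train" dataset across workers; replicate all other datasets."""
    shards: List[Dict[str, Dataset]] = [{} for _ in worker_handles]
    for key, ds in dataset_dict.items():
        if key == "train":
            splits = ds.split(
                len(worker_handles), equal=True, locality_hints=worker_handles
            )
        else:
            splits = [ds] * len(worker_handles)
        for i, split in enumerate(splits):
            shards[i][key] = split
    return shards


if __name__ == "__main__":
    ray.init()
    workers = [Worker.remote() for _ in range(2)]
    datasets = {"train": ray.data.range(8), "validation": ray.data.range(4)}
    # One dict per worker: an equal "train" shard plus the full "validation" set.
    for i, shard in enumerate(only_shard_train(datasets, workers)):
        print(i, {name: ds.count() for name, ds in shard.items()})
```

Because non-train datasets are replicated rather than split, every worker keeps a complete copy of them (for example a validation set), while the training data is divided evenly across workers.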
The second changed file adds the `RayDatasetSpec` dataclass. Its import block gains `dataclass`:

```diff
@@ -1,4 +1,5 @@
 import abc
+from dataclasses import dataclass
 import inspect
 import os
 import logging
```
```diff
@@ -172,3 +173,88 @@ def __getattr__(self, item):
         # actor.
         actor_method = getattr(self.actor, item)
         return lambda *args, **kwargs: ray.get(actor_method.remote(*args, **kwargs))
+
+
+@dataclass
+class RayDatasetSpec:
+    """Configuration for Ray Datasets to pass to the training workers.
+
+    dataset_or_dict: An optional Ray Dataset (or DatasetPipeline) or a dictionary of
+        datasets to be sharded across all the training workers, which can be accessed
+        from the training function via ``train.get_dataset_shard()``. Multiple Datasets
+        can be passed in as a ``Dict`` that maps each name key to a Dataset value,
+        and each Dataset can be accessed from the training function by passing in a
+        `dataset_name` argument to ``train.get_dataset_shard()``.
+    dataset_split_fn: An optional callable to specify how the provided ``dataset``
+        should be split across the training workers. It is expected to take in two
+        arguments. The first one is the ``dataset``, just as is passed in to the
+        ``RayDatasetSpec``. The second argument is a list of the ActorHandles of the
+        training workers (to use as locality hints). The Callable is expected to
+        return a list of RayDatasets or a list of dictionaries of RayDatasets,
+        with the length of the list equal to the length of the list of actor handles.
+        If None is provided, the provided Ray Dataset(s) will simply be split using
+        the actor handles as locality hints.
+    """
+
+    dataset_or_dict: Optional[Union[RayDataset, Dict[str, RayDataset]]]
+    dataset_split_fn: Optional[
+        Callable[
+            [Union[RayDataset, Dict[str, RayDataset]], List[ActorHandle]],
+            List[Union[RayDataset, Dict[str, RayDataset]]],
+        ]
+    ] = None
+
+    def _default_split_fn(
+        self, training_worker_handles: List[ActorHandle]
+    ) -> List[Optional[Union[RayDataset, Dict[str, RayDataset]]]]:
+        def split_dataset(dataset_or_pipeline):
+            return dataset_or_pipeline.split(
+                len(training_worker_handles),
+                equal=True,
+                locality_hints=training_worker_handles,
+            )
+
+        if isinstance(self.dataset_or_dict, dict):
+            # Return a smaller dict for each shard.
+            dataset_shards = [{} for _ in range(len(training_worker_handles))]
+            for key, dataset in self.dataset_or_dict.items():
+                split_datasets = split_dataset(dataset)
+                assert len(split_datasets) == len(training_worker_handles)
+                for i in range(len(split_datasets)):
+                    dataset_shards[i][key] = split_datasets[i]
+            return dataset_shards
+        else:
+            # Return a smaller RayDataset for each shard.
+            return split_dataset(self.dataset_or_dict)
+
+    def get_dataset_shards(
+        self, training_worker_handles: List[ActorHandle]
+    ) -> List[Optional[Union[RayDataset, Dict[str, RayDataset]]]]:
+        """Returns Dataset splits based on the spec and the given training workers.
+
+        Args:
+            training_worker_handles: A list of the training worker actor handles.
+
+        Returns:
+            A list of RayDataset shards or list of dictionaries of RayDataset shards,
+            one for each training worker.
+        """
+        if self.dataset_or_dict is None:
+            # If no Dataset is provided, return None for each shard.
+            return [None] * len(training_worker_handles)
+
+        if self.dataset_split_fn is None:
+            return self._default_split_fn(training_worker_handles)
+        else:
+            splits = self.dataset_split_fn(
+                self.dataset_or_dict, training_worker_handles
+            )
+            if not len(splits) == len(training_worker_handles):
+                raise RuntimeError(
+                    "The list of Datasets returned by the "
+                    f"`dataset_split_fn`: {len(splits)} does not match "
+                    f"the number of training workers: {len(training_worker_handles)}"
+                )
+            return splits
```

Review comment (on `_default_split_fn`): I'm confused why we need both [...].
Review comment: More specifically, the default implementation for [...]. In the future, users should be able to override the behavior for [...].

Review comment (on the placement of `RayDatasetSpec`): How about moving this class into a separate file, such as [...]? (Also, this should go into [...].)
Review comment: +1 on moving the dataset spec to its own module.
Reply: Moved to its own module! But kept it as part of [...].
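For orientation, here is a sketch of exercising the spec directly rather than through a trainer. It only works with this PR applied, assumes `RayDatasetSpec` remains importable from `ray.train.utils` as in the first hunk (its placement is under discussion in the comments above), and the two-worker setup and dataset sizes are arbitrary.

```python
import ray
from ray.train.utils import RayDatasetSpec  # import location per this PR's first hunk


@ray.remote
class Worker:
    """Stand-in for Ray Train's worker actors."""


ray.init()
workers = [Worker.remote() for _ in range(2)]

# Default path: no dataset_split_fn, so the dataset is split equally across
# the workers, with the actor handles used as locality hints.
spec = RayDatasetSpec(dataset_or_dict=ray.data.range(8))
shards = spec.get_dataset_shards(workers)
assert len(shards) == 2

# A custom dataset_split_fn must return exactly one entry per worker handle;
# otherwise get_dataset_shards raises a RuntimeError.
bad_spec = RayDatasetSpec(
    dataset_or_dict=ray.data.range(8),
    dataset_split_fn=lambda ds, handles: [ds],  # wrong length on purpose
)
try:
    bad_spec.get_dataset_shards(workers)
except RuntimeError as err:
    print("caught:", err)
```

Note that the length check is the only validation performed on a custom split function; a function that returns the right number of entries but, say, mismatched dictionary keys would pass through unchecked.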
Review comment (on the `dataset_split_fn` defined in `training_loop`): Could we pull this out into a default splitting function in the util module?
Review comment: If we're pulling this out into a default splitting function, could you add a docstring? It would allow readers to understand the function without having to reference `_RayDatasetSpec`.
Reply: Separated it into its own function, but left it in `data_parallel_trainer` for now as it is only being used in `DataParallelTrainer`. Let's revisit the location if we need it in more trainers in the future.
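On the consumer side, the `RayDatasetSpec` docstring states that workers retrieve their data via `train.get_dataset_shard()`. The following is a rough, illustrative sketch of what a `train_loop_per_worker` might do with its shards; the loop body, batch size, and config keys are placeholders, not part of this PR.

```python
from ray import train


def train_loop_per_worker(config):
    # Each worker sees only its own shard of the "train" dataset; any other
    # datasets are replicated whole, per the split function above.
    train_shard = train.get_dataset_shard("train")
    for epoch in range(config.get("epochs", 1)):
        for batch in train_shard.iter_batches(batch_size=32):
            pass  # the forward/backward pass would go here
```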