[AIR/Train] Make Dataset ingest configurable #24066

Merged: amogkam merged 18 commits into ray-project:master from amogkam:air-dataset-split-refactor on Apr 28, 2022.
Changes from all commits (18, all by amogkam):

5275cbe refactor dataset splitting
7e7c927 Update python/ray/train/utils.py
13537ca make private
8783055 Merge branch 'air-dataset-split-refactor' of github.com:amogkam/ray i…
e2348ba separate to own module
cf9a077 add file
7c345ba format
8b25c73 Revert "separate to own module"
9bb124e fix
d5819c7 separate function
12664c3 move to impl
6cdc8f0 fix
8ab12da fix
8ed0843 fix
90ef2fa Merge branch 'master' of https://github.com/ray-project/ray into air-…
31d0501 Merge branch 'master' of https://github.com/ray-project/ray into air-…
f1e4e31 fix
f42044a update tests
New file (+93 lines):

from dataclasses import dataclass
from typing import Optional, Union, Dict, Callable, List, TYPE_CHECKING

from ray.actor import ActorHandle

if TYPE_CHECKING:
    from ray.data import Dataset, DatasetPipeline

RayDataset = Union["Dataset", "DatasetPipeline"]


@dataclass
class _RayDatasetSpec:
    """Configuration for Ray Datasets to pass to the training workers.

    dataset_or_dict: An optional Ray Dataset (or DatasetPipeline) or a dictionary
        of datasets to be sharded across all the training workers, which can be
        accessed from the training function via ``train.get_dataset_shard()``.
        Multiple Datasets can be passed in as a dictionary that maps each name key
        to a Dataset value, and each Dataset can be accessed from the training
        function by passing a ``dataset_name`` argument to
        ``train.get_dataset_shard()``.
    dataset_split_fn: An optional callable to specify how the provided ``dataset``
        should be split across the training workers. It is expected to take two
        arguments. The first is the ``dataset``, just as it is passed in to the
        ``_RayDatasetSpec``. The second is a list of the ActorHandles of the
        training workers (to use as locality hints). The callable is expected to
        return a list of RayDatasets or a list of dictionaries of RayDatasets,
        with the length of the list equal to the length of the list of actor
        handles. If None is provided, the provided Ray Dataset(s) will simply be
        split using the actor handles as locality hints.
    """

    dataset_or_dict: Optional[Union[RayDataset, Dict[str, RayDataset]]]
    dataset_split_fn: Optional[
        Callable[
            [Union[RayDataset, Dict[str, RayDataset]], List[ActorHandle]],
            List[Union[RayDataset, Dict[str, RayDataset]]],
        ]
    ] = None

    def _default_split_fn(
        self, training_worker_handles: List[ActorHandle]
    ) -> List[Optional[Union[RayDataset, Dict[str, RayDataset]]]]:
        def split_dataset(dataset_or_pipeline):
            return dataset_or_pipeline.split(
                len(training_worker_handles),
                equal=True,
                locality_hints=training_worker_handles,
            )

        if isinstance(self.dataset_or_dict, dict):
            # Return a smaller dict for each shard.
            dataset_shards = [{} for _ in range(len(training_worker_handles))]
            for key, dataset in self.dataset_or_dict.items():
                split_datasets = split_dataset(dataset)
                assert len(split_datasets) == len(training_worker_handles)
                for i in range(len(split_datasets)):
                    dataset_shards[i][key] = split_datasets[i]
            return dataset_shards
        else:
            # Return a smaller RayDataset for each shard.
            return split_dataset(self.dataset_or_dict)

    def get_dataset_shards(
        self, training_worker_handles: List[ActorHandle]
    ) -> List[Optional[Union[RayDataset, Dict[str, RayDataset]]]]:
        """Returns Dataset splits based on the spec and the given training workers.

        Args:
            training_worker_handles: A list of the training worker actor handles.

        Returns:
            A list of RayDataset shards or a list of dictionaries of RayDataset
            shards, one for each training worker.
        """
        if not self.dataset_or_dict:
            return [None] * len(training_worker_handles)

        if self.dataset_split_fn is None:
            return self._default_split_fn(training_worker_handles)
        else:
            splits = self.dataset_split_fn(
                self.dataset_or_dict, training_worker_handles
            )
            if not len(splits) == len(training_worker_handles):
                raise RuntimeError(
                    "The list of Datasets returned by the "
                    f"`dataset_split_fn`: {len(splits)} does not match "
                    f"the number of training workers: "
                    f"{len(training_worker_handles)}"
                )
            return splits
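To illustrate the contract a custom `dataset_split_fn` must satisfy (one shard per worker handle, dict inputs mapped to per-shard dicts), here is a minimal sketch. It uses plain Python lists as hypothetical stand-ins for Ray Datasets and strings as stand-ins for ActorHandles; real code would call `Dataset.split()` and receive actual actor handles. The names `even_split` and `custom_split_fn` are illustrative, not part of the PR.

```python
# Hypothetical sketch of the dataset_split_fn contract. Plain lists stand in
# for Ray Datasets and strings for ActorHandles; a real split function would
# call Dataset.split() with locality hints.

def even_split(items, num_shards):
    # Round-robin assignment: shard i receives items[i::num_shards].
    return [items[i::num_shards] for i in range(num_shards)]

def custom_split_fn(dataset_or_dict, worker_handles):
    # Must return exactly one entry per worker handle, mirroring the
    # length check in _RayDatasetSpec.get_dataset_shards().
    n = len(worker_handles)
    if isinstance(dataset_or_dict, dict):
        # Split each named dataset, then regroup into one dict per shard.
        per_key = {k: even_split(v, n) for k, v in dataset_or_dict.items()}
        return [
            {k: shards[i] for k, shards in per_key.items()} for i in range(n)
        ]
    return even_split(dataset_or_dict, n)

workers = ["worker-0", "worker-1"]  # stand-ins for ActorHandles
shards = custom_split_fn({"train": [1, 2, 3, 4], "val": [5, 6]}, workers)
```

Here `shards` has one dict per worker, so the `RuntimeError` check in `get_dataset_shards` would pass.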
Review comment: Could we move this into the dataset spec file?
Reply: This function is specific to DataParallelTrainer, not to DatasetSpec in general.