[AIR/Train] Make Dataset ingest configurable #24066

Merged · 18 commits · Apr 28, 2022
39 changes: 24 additions & 15 deletions python/ray/ml/train/data_parallel_trainer.py
@@ -14,7 +14,7 @@
from ray.train import BackendConfig, TrainingIterator
from ray.train.backend import BackendExecutor
from ray.train.checkpoint import TuneCheckpointManager
from ray.train.utils import construct_train_func
from ray.train.utils import construct_train_func, RayDatasetSpec
from ray.util.annotations import DeveloperAPI

logger = logging.getLogger(__name__)
@@ -292,27 +292,36 @@ def training_loop(self) -> None:
else:
resume_checkpoint_dict = None

# Tell Ray Train to only shard the train dataset and not the other datasets.
# This is purely an implementation detail and users do not need to know about
# this.
# TODO(amog): Refactor this to remove hack and make this more modular.
# TrainingIterator should accept a generic custom_ingest_func that contains
# the logic for how to split the Datasets.
updated_dataset_dict = {}
for key, value in self.datasets.items():
if key == TRAIN_DATASET_KEY:
updated_dataset_dict[key] = value
else:
# Ray Train will strip out the added string before exposing to users.
updated_dataset_dict[key + "_NO-SHARD"] = value
def dataset_split_fn(dataset_dict, training_worker_handles):
Contributor:

Could we pull this out into a default splitting function in the util module?

Member:

If we're pulling this out into a default splitting function, could you add a docstring? Would allow readers to understand the function without having to reference _RayDatasetSpec.

Contributor Author:

Separated into its own function, but left it in data_parallel_trainer for now as it is only being used in DataParallelTrainer. Let's revisit the location if we need it in more trainers in the future.

Member:

Could we add a type annotation to training_worker_handles? The type wasn't obvious until I read _RayDatasetSpec.

Contributor Author:

Added!

dataset_dict_splits = [{} for _ in range(len(training_worker_handles))]

for key, dataset in dataset_dict.items():
if key == TRAIN_DATASET_KEY:
dataset_splits = dataset.split(
len(training_worker_handles),
equal=True,
locality_hints=training_worker_handles,
)
else:
# Only the train dataset is sharded; every other dataset is passed to each worker in full.
Member:

Could you add a comment explaining why we're only sharding the training dataset?

Contributor Author:

Added!

dataset_splits = [dataset] * len(training_worker_handles)

for i in range(len(dataset_splits)):
dataset_dict_splits[i][key] = dataset_splits[i]

return dataset_dict_splits

dataset_spec = RayDatasetSpec(
dataset_or_dict=self.datasets, dataset_split_fn=dataset_split_fn
)

# TODO(amog): Have TrainingIterator also accept a checkpoint ObjectRef instead
# of just a Dict.
training_iterator = TrainingIterator(
backend_executor=backend_executor,
backend_config=self.backend_config,
train_func=train_loop_per_worker,
dataset=updated_dataset_dict if len(updated_dataset_dict) > 0 else None,
dataset_spec=dataset_spec,
checkpoint_manager=checkpoint_manager,
checkpoint=resume_checkpoint_dict,
checkpoint_strategy=None,
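
For reference, a minimal sketch (not part of the diff) of what the new dataset_split_fn produces for a two-worker run: the "train" dataset is split across the workers, while every other dataset is handed to each worker whole. The dataset names, sizes, and worker count below are illustrative.

import ray
import ray.data

ray.init()

datasets = {
    "train": ray.data.range(8),  # sharded: each worker sees a disjoint half
    "valid": ray.data.range(4),  # not sharded: each worker sees all 4 rows
}
num_workers = 2

splits = {
    key: ds.split(num_workers, equal=True) if key == "train" else [ds] * num_workers
    for key, ds in datasets.items()
}
# Re-group into one dict per worker, mirroring the list that dataset_split_fn returns.
per_worker = [{key: splits[key][i] for key in datasets} for i in range(num_workers)]
assert per_worker[0]["train"].count() == 4
assert per_worker[0]["valid"].count() == 4
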
4 changes: 2 additions & 2 deletions python/ray/ml/trainer.py
@@ -257,9 +257,9 @@ def preprocess_datasets(self) -> None:
If the ``Trainer`` has both a datasets dict and
a preprocessor, the datasets dict contains a training dataset (denoted by
the "train" key), and the preprocessor has not yet
been fit, then it will be fit on the train.
been fit, then it will be fit on the train dataset.

Then, the Trainer's datasets will be transformed by the preprocessor.
Then, all the Trainer's datasets will be transformed by the preprocessor.

The transformed datasets will be set back in the ``self.datasets`` attribute
of the Trainer to be used when overriding ``training_loop``.
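
A self-contained sketch of the flow this docstring describes, assuming the Ray ML Preprocessor interface with fit and transform; the helper below is illustrative and is not the Trainer's actual method.

from typing import Dict


def preprocess_datasets_sketch(preprocessor, datasets: Dict[str, "ray.data.Dataset"]):
    # Sketch only: fit on the "train" dataset (the real method skips fitting if
    # the preprocessor was already fit), then transform every dataset with it.
    if preprocessor is None:
        return datasets
    if "train" in datasets:
        preprocessor.fit(datasets["train"])
    return {name: preprocessor.transform(ds) for name, ds in datasets.items()}
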
53 changes: 8 additions & 45 deletions python/ray/train/backend.py
@@ -1,7 +1,7 @@
import logging
import os
from collections import defaultdict
from typing import Callable, TypeVar, List, Optional, Dict, Union, Type, Tuple
from typing import Callable, TypeVar, List, Optional, Dict, Type, Tuple

import ray
from ray.exceptions import RayActorError
@@ -14,7 +14,7 @@
)
from ray.train.session import TrainingResult
from ray.train.session import init_session, get_session, shutdown_session
from ray.train.utils import RayDataset, check_for_failure, Singleton
from ray.train.utils import RayDatasetSpec, check_for_failure, Singleton
from ray.train.worker_group import WorkerGroup
from ray.util.annotations import DeveloperAPI
from ray.util.placement_group import get_current_placement_group, remove_placement_group
@@ -314,42 +314,10 @@ def _create_local_rank_map(self) -> Dict:
ip_dict[node_ip] += 1
return rank_mapping

def _get_dataset_shards(self, dataset_or_dict):

if dataset_or_dict is None:
# Return None for each shard.
return [None] * len(self.worker_group)

def split_dataset(dataset_or_pipeline):
actors = [worker.actor for worker in self.worker_group.workers]
return dataset_or_pipeline.split(
len(self.worker_group), equal=True, locality_hints=actors
)

if isinstance(dataset_or_dict, dict):
# Return a smaller dict for each shard.
dataset_shards = [{} for _ in range(len(self.worker_group))]
# TODO(amog): Update Backend to accept a generic function with logic on
# how to split dataset, instead of having to support _NO-SHARD in key.
for key, dataset in dataset_or_dict.items():
if "_NO-SHARD" in key:
# Do not shard this dataset.
split_datasets = [dataset] * len(self.worker_group)
key = key.replace("_NO-SHARD", "")
else:
split_datasets = split_dataset(dataset)
assert len(split_datasets) == len(self.worker_group)
for i in range(len(split_datasets)):
dataset_shards[i][key] = split_datasets[i]
return dataset_shards
else:
# return a smaller RayDataset for each shard.
return split_dataset(dataset_or_dict)

def start_training(
self,
train_func: Callable[[], T],
dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]] = None,
dataset_spec: RayDatasetSpec = None,
checkpoint: Optional[Dict] = None,
) -> None:
"""Executes a training function on all workers in a separate thread.
@@ -358,15 +326,9 @@ def start_training(

Args:
train_func (Callable): The training function to run on each worker.
dataset (Optional[Union[Dataset, DatasetPipeline]])
Distributed Ray Dataset or DatasetPipeline to pass into
worker, which can be accessed from the training function via
``train.get_dataset_shard()``. Sharding will automatically be
handled by the Trainer. Multiple Datasets can be passed in as
a ``Dict`` that maps each name key to a Dataset value,
and each Dataset can be accessed from the training function
by passing in a `dataset_name` argument to
``train.get_dataset_shard()``.
dataset_spec (RayDatasetSpec): A specification for the Ray Dataset to be
passed to the training workers, and the logic on how to shard the Ray
Dataset.
checkpoint (Optional[Dict]): The checkpoint data that
should be loaded onto each worker and accessed by the
training function via ``train.load_checkpoint()``. If this
@@ -406,7 +368,8 @@ def initialize_session(
)

if self.dataset_shards is None:
self.dataset_shards = self._get_dataset_shards(dataset)
actors = [worker.actor for worker in self.worker_group.workers]
self.dataset_shards = dataset_spec.get_dataset_shards(actors)

local_rank_map = self._create_local_rank_map()

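
A rough sketch of the new flow above: instead of sharding datasets inline, BackendExecutor passes the worker actor handles to the spec and lets it produce the shards. The import path follows this diff (a review comment below suggests moving the class to its own module), and the Worker actor is a stand-in for real training worker actors.

import ray
import ray.data
from ray.train.utils import RayDatasetSpec  # path as added in this diff

ray.init()


@ray.remote
class Worker:
    pass


actors = [Worker.remote() for _ in range(2)]

spec = RayDatasetSpec(dataset_or_dict=ray.data.range(10))
shards = spec.get_dataset_shards(actors)  # default: equal, locality-hinted splits
assert len(shards) == len(actors)
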
29 changes: 20 additions & 9 deletions python/ray/train/trainer.py
@@ -15,7 +15,12 @@
)
from ray.train.callbacks.callback import TrainingCallback
from ray.train.session import TrainingResultType
from ray.train.utils import RayDataset, construct_train_func, ActorWrapper
from ray.train.utils import (
RayDataset,
construct_train_func,
ActorWrapper,
RayDatasetSpec,
)
from ray.train.checkpoint import (
CheckpointStrategy,
TuneCheckpointManager,
@@ -320,12 +325,14 @@ def run(

train_func = construct_train_func(train_func, config)

dataset_spec = RayDatasetSpec(dataset_or_dict=dataset)

try:
iterator = TrainingIterator(
backend_executor=self._backend_executor,
backend_config=self._backend_config,
train_func=train_func,
dataset=dataset,
dataset_spec=dataset_spec,
checkpoint_manager=self.checkpoint_manager,
checkpoint=checkpoint,
checkpoint_strategy=checkpoint_strategy,
@@ -397,12 +404,14 @@ def train_func(config):

train_func = construct_train_func(train_func, config)

dataset_spec = RayDatasetSpec(dataset_or_dict=dataset)

return TrainingIterator(
backend_executor=self._backend_executor,
backend_config=self._backend_config,
train_func=train_func,
run_dir=self.latest_run_dir,
dataset=dataset,
dataset_spec=dataset_spec,
checkpoint_manager=self.checkpoint_manager,
checkpoint=checkpoint,
checkpoint_strategy=checkpoint_strategy,
@@ -634,7 +643,7 @@ def __init__(
backend_executor: Union[BackendExecutor, ActorWrapper],
backend_config: BackendConfig,
train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]],
dataset_spec: RayDatasetSpec,
checkpoint_manager: CheckpointManager,
checkpoint: Optional[Union[Dict, str, Path]],
checkpoint_strategy: Optional[CheckpointStrategy],
@@ -643,14 +652,14 @@
self._backend_executor = backend_executor
self._backend = backend_config.backend_cls()
self._train_func = train_func
self._dataset = dataset
self._dataset_spec = dataset_spec
self._run_dir = run_dir
self._checkpoint_manager = checkpoint_manager
self._checkpoint_strategy = checkpoint_strategy
self._start_training(
train_func=train_func,
run_dir=run_dir,
dataset=dataset,
dataset_spec=self._dataset_spec,
checkpoint=checkpoint,
checkpoint_strategy=checkpoint_strategy,
)
@@ -665,7 +674,7 @@ def _start_training(
self,
train_func,
run_dir,
dataset,
dataset_spec,
checkpoint,
checkpoint_strategy,
latest_checkpoint_id=None,
@@ -678,7 +687,9 @@
checkpoint_dict = self._checkpoint_manager._load_checkpoint(checkpoint)
self._run_with_error_handling(
lambda: self._backend_executor.start_training(
train_func=train_func, dataset=dataset, checkpoint=checkpoint_dict
train_func=train_func,
dataset_spec=dataset_spec,
checkpoint=checkpoint_dict,
)
)

@@ -697,7 +708,7 @@ def _run_with_error_handling(self, func: Callable):
self._start_training(
self._train_func,
self._run_dir,
self._dataset,
self._dataset_spec,
self._checkpoint_manager.latest_checkpoint,
self._checkpoint_strategy,
latest_checkpoint_id=self._checkpoint_manager.latest_checkpoint_id,
86 changes: 86 additions & 0 deletions python/ray/train/utils.py
@@ -1,4 +1,5 @@
import abc
from dataclasses import dataclass
import inspect
import os
import logging
@@ -172,3 +173,88 @@ def __getattr__(self, item):
# actor.
actor_method = getattr(self.actor, item)
return lambda *args, **kwargs: ray.get(actor_method.remote(*args, **kwargs))


@dataclass
class RayDatasetSpec:
"""Configuration for Ray Datasets to pass to the training workers.

dataset_or_dict: An optional Ray Dataset (or DatasetPipeline) or a dictionary of
datasets to be sharded across all the training workers, which can be accessed
from the training function via ``train.get_dataset_shard()``. Multiple Datasets
can be passed in as a ``Dict`` that maps each name key to a Dataset value,
and each Dataset can be accessed from the training function by passing in a
`dataset_name` argument to ``train.get_dataset_shard()``.
dataset_split_fn: An optional callable to specify how the provided ``dataset``
should be split across the training workers. It is expected to take in two
arguments. The first one is the ``dataset``, just as is passed in to the
``RayDatasetSpec``. The second argument is a list of the ActorHandles of the
training workers (to use as locality hints). The Callable is expected to
return a list of RayDatasets or a list of dictionaries of RayDatasets,
with the length of the list equal to the length of the list of actor handles.
If None is provided, the Ray Dataset(s) will simply be split using
the actor handles as locality hints.

"""

dataset_or_dict: Optional[Union[RayDataset, Dict[str, RayDataset]]]
dataset_split_fn: Optional[
Callable[
[Union[RayDataset, Dict[str, RayDataset]], List[ActorHandle]],
List[Union[RayDataset, Dict[str, RayDataset]]],
]
] = None

def _default_split_fn(
Member:

I'm confused why we need both _RayDatasetSpec._default_split_fn and dataset_split_fn in training_loop. Isn't dataset_split_fn the default for training?

Contributor Author:

dataset_split_fn is the implementation that DataParallelTrainer uses, but is not the default for RayDatasetSpec in general.

Contributor Author:

More specifically, the default implementation for RayDatasetSpec is to split all datasets.

DataParallelTrainer is overriding this behavior to split just the train dataset, but not split the other datasets.

In the future, users should be able to override the behavior for DataParallelTrainer.

self, training_worker_handles: List[ActorHandle]
) -> List[Optional[Union[RayDataset, Dict[str, RayDataset]]]]:
def split_dataset(dataset_or_pipeline):
return dataset_or_pipeline.split(
len(training_worker_handles),
equal=True,
locality_hints=training_worker_handles,
)

if isinstance(self.dataset_or_dict, dict):
# Return a smaller dict for each shard.
dataset_shards = [{} for _ in range(len(training_worker_handles))]
for key, dataset in self.dataset_or_dict.items():
split_datasets = split_dataset(dataset)
assert len(split_datasets) == len(training_worker_handles)
for i in range(len(split_datasets)):
dataset_shards[i][key] = split_datasets[i]
return dataset_shards
else:
# return a smaller RayDataset for each shard.
return split_dataset(self.dataset_or_dict)

def get_dataset_shards(
self, training_worker_handles: List[ActorHandle]
) -> List[Optional[Union[RayDataset, Dict[str, RayDataset]]]]:
"""Returns Dataset splits based off the spec and the given training workers

Args:
training_worker_handles: A list of the training worker actor handles.

Returns:
A list of RayDataset shards or list of dictionaries of RayDataset shards,
one for each training worker.

"""
if self.dataset_or_dict is None:
# If no Dataset is provided, return None for each shard.
return [None] * len(training_worker_handles)

if self.dataset_split_fn is None:
return self._default_split_fn(training_worker_handles)
else:
splits = self.dataset_split_fn(
self.dataset_or_dict, training_worker_handles
)
if not len(splits) == len(training_worker_handles):
raise RuntimeError(
"The list of Datasets returned by the "
Contributor:

How about moving this class into a separate file, such as ml/train/impl/dataset_spec.py?

(also, this should go into ml/train for now right? given that ray/train is deprecated).

Member:

+1 on moving dataset spec to its own module.

Contributor Author:

Moved to its own module!

But kept it as part of ray/train. It's being used by current Ray Train, and as discussed offline, the end state is to eventually move ray/ml/train to ray/train anyways.

f"`dataset_split_fn`: {len(splits)} does not match "
f"the number of training workers: {len(training_worker_handles)}"
)
return splits