[1/n] Lightweight Ray AIR API refactor (#36706)
This PR removes some circular imports in Ray AIR so that the training-related functions can live in `ray.train`. It introduces a training context and makes `report`, `get_dataset_shard`, `Checkpoint`, `Result`, and the following configs:

- CheckpointConfig
- DataConfig
- FailureConfig
- RunConfig
- ScalingConfig

available in `ray.train`. There are no user-facing changes yet; the old APIs still work.

Going forward, it will be most consistent and symmetrical to import these names as follows:

```python
from ray import train, tune, serve  # Pick the subset that is needed
# Include what you need from the following:
from ray.train import CheckpointConfig, DataConfig, FailureConfig, RunConfig, ScalingConfig
import ray.train.torch  # Load the submodule so that train.torch resolves below

# ...

def train_func():
    dataset_shard = train.get_dataset_shard("train")
    world_size = train.get_context().get_world_size()
    # ...
    train.report(...)

trainer = train.torch.TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```

#37123 contains many examples of how this looks in actual code.
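
To make the idea of a training context concrete, here is a minimal standard-library sketch of the pattern behind calls such as `train.get_context().get_world_size()` and `train.report(...)`. It is not Ray's implementation: `ToyTrainContext`, `set_context`, `run_workers`, and the `get_world_rank` accessor are hypothetical names chosen for illustration, and threads stand in for Ray workers.

```python
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyTrainContext:
    """Hypothetical per-worker context; a stand-in, not Ray's TrainContext."""

    world_size: int
    world_rank: int
    reported: List[Dict[str, Any]] = field(default_factory=list)

    def get_world_size(self) -> int:
        return self.world_size

    def get_world_rank(self) -> int:
        return self.world_rank


# Thread-local storage, so each worker thread sees only its own context.
_local = threading.local()


def set_context(ctx: ToyTrainContext) -> None:
    """Install the context for the current worker (hypothetical helper)."""
    _local.ctx = ctx


def get_context() -> ToyTrainContext:
    """Return the current worker's context, or fail outside of a worker."""
    ctx = getattr(_local, "ctx", None)
    if ctx is None:
        raise RuntimeError("get_context() called outside of a training worker")
    return ctx


def report(metrics: Dict[str, Any]) -> None:
    """Record metrics on the current worker's context."""
    get_context().reported.append(dict(metrics))


def train_func() -> None:
    # User code only touches module-level functions, never the worker plumbing.
    ctx = get_context()
    report({"rank": ctx.get_world_rank(), "world_size": ctx.get_world_size()})


def run_workers(num_workers: int) -> List[List[Dict[str, Any]]]:
    """Run train_func once per worker thread and collect what each reported."""
    contexts = [
        ToyTrainContext(world_size=num_workers, world_rank=rank)
        for rank in range(num_workers)
    ]

    def worker(ctx: ToyTrainContext) -> None:
        set_context(ctx)
        train_func()

    threads = [threading.Thread(target=worker, args=(c,)) for c in contexts]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return [c.reported for c in contexts]
```

Because user code reaches the context only through module-level functions, those functions can be re-exported from a single package without the user-facing module importing the trainer machinery, which is one common way to avoid import cycles.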
pcmoritz authored Jul 8, 2023
1 parent a7852a4 commit dd029b6
Showing 11 changed files with 529 additions and 417 deletions.
19 changes: 19 additions & 0 deletions doc/source/train/api/api.rst
@@ -52,6 +52,25 @@ Train Backend Base Classes
~train.backend.Backend
~train.backend.BackendConfig

Ray Train Config
----------------

.. autosummary::

~ray.train.DataConfig


Ray Train Loop
--------------

.. autosummary::
:toctree: doc/

~train.context.TrainContext
~train.get_context
~train.get_dataset_shard
~train.report


.. _train-integration-api:
.. _train-framework-specific-ckpts:
10 changes: 5 additions & 5 deletions python/ray/air/checkpoint.py
@@ -331,7 +331,7 @@ def from_bytes(cls, data: bytes) -> "Checkpoint":
data: Data object containing pickled checkpoint data.
Returns:
-Checkpoint: checkpoint object.
+ray.air.checkpoint.Checkpoint: checkpoint object.
"""
bytes_data = pickle.loads(data)
if isinstance(bytes_data, dict):
@@ -360,7 +360,7 @@ def from_dict(cls, data: dict) -> "Checkpoint":
data: Dictionary containing checkpoint data.
Returns:
-Checkpoint: checkpoint object.
+ray.air.checkpoint.Checkpoint: checkpoint object.
"""
state = {}
if _METADATA_KEY in data:
@@ -455,7 +455,7 @@ def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint":
Checkpoint).
Returns:
-Checkpoint: checkpoint object.
+ray.air.checkpoint.Checkpoint: checkpoint object.
"""
state = {}

@@ -474,7 +474,7 @@ def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint":
@classmethod
@DeveloperAPI
def from_checkpoint(cls, other: "Checkpoint") -> "Checkpoint":
-"""Create a checkpoint from a generic :class:`Checkpoint`.
+"""Create a checkpoint from a generic :class:`ray.air.checkpoint.Checkpoint`.
This method can be used to create a framework-specific checkpoint from a
generic :class:`Checkpoint` object.
@@ -715,7 +715,7 @@ def from_uri(cls, uri: str) -> "Checkpoint":
uri: Source location URI to read data from.
Returns:
-Checkpoint: checkpoint object.
+ray.air.checkpoint.Checkpoint: checkpoint object.
"""
state = {}
try: