Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[1/n] Lightweight Ray AIR API refactor (#36706)
This PR removes some circularities in the Ray AIR import system so we can put the training related functions into `ray.train`. It introduces a training context and makes report, get_dataset_shard, Checkpoint, Result, and the following configs: - CheckpointConfig - DataConfig - FailureConfig - RunConfig - ScalingConfig available in `ray.train`. No user facing changes yet, the old APIs still work. Going forward, it will be most consistent / symmetrical if these things are included in the following way: ```python from ray import train, tune, serve # Pick the subset that is needed # Include what you need from the following: from ray.train import CheckpointConfig, DataConfig, FailureConfig, RunConfig, ScalingConfig # ... def train_func(): dataset_shard = train.get_dataset_shard("train") world_size = train.get_context().get_world_size() # ... train.report(...) trainer = train.torch.TorchTrainer( train_func, scaling_config=ScalingConfig(num_workers=2), ) result = trainer.fit() ``` We have many examples in #37123 on how this looks like in actual code.
- Loading branch information