48 changes: 30 additions & 18 deletions python/ray/train/gbdt_trainer.py
@@ -1,13 +1,16 @@
 import os
 import logging
+import tempfile
 import warnings
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type
 
-from ray import tune
+from ray import train, tune
 from ray.air._internal.checkpointing import save_preprocessor_to_dir
-from ray.air.checkpoint import Checkpoint
+from ray.air.checkpoint import Checkpoint as LegacyCheckpoint
+from ray.train._checkpoint import Checkpoint
 from ray.air.config import RunConfig, ScalingConfig
+from ray.train._internal.storage import _use_storage_context
 from ray.train.constants import MODEL_KEY, TRAIN_DATASET_KEY
 from ray.train.trainer import BaseTrainer, GenDataset
 from ray.tune import Trainable
@@ -145,9 +148,11 @@ class GBDTTrainer(BaseTrainer):
 
     _dmatrix_cls: type
     _ray_params_cls: type
-    _tune_callback_report_cls: type
     _tune_callback_checkpoint_cls: type
-    _default_ray_params: Dict[str, Any] = {"checkpoint_frequency": 1}
+    _default_ray_params: Dict[str, Any] = {
+        "checkpoint_frequency": 1,
+        "checkpoint_at_end": True,
+    }
     _init_model_arg_name: str
     _num_iterations_argument: str = "num_boost_round"
     _default_num_iterations: int = _DEFAULT_NUM_ITERATIONS
@@ -163,7 +168,7 @@ def __init__(
         scaling_config: Optional[ScalingConfig] = None,
         run_config: Optional[RunConfig] = None,
         preprocessor: Optional["Preprocessor"] = None,
-        resume_from_checkpoint: Optional[Checkpoint] = None,
+        resume_from_checkpoint: Optional[LegacyCheckpoint] = None,
         metadata: Optional[Dict[str, Any]] = None,
         **train_kwargs,
     ):
@@ -219,7 +224,7 @@ def _get_dmatrices(
 
     def _load_checkpoint(
         self,
-        checkpoint: Checkpoint,
+        checkpoint: LegacyCheckpoint,
     ) -> Tuple[Any, Optional["Preprocessor"]]:
         raise NotImplementedError
 
@@ -269,9 +274,21 @@ def _checkpoint_at_end(self, model, evals_result: dict) -> None:
         for k in list(result_dict):
             result_dict[k] = result_dict[k][-1]
 
-        with tune.checkpoint_dir(step=self._model_iteration(model)) as cp_dir:
-            self._save_model(model, path=os.path.join(cp_dir, MODEL_KEY))
-        tune.report(**result_dict)
+        if getattr(self._tune_callback_checkpoint_cls, "_report_callbacks_cls", None):
+            # Deprecate: Remove in Ray 2.8
+            with tune.checkpoint_dir(step=self._model_iteration(model)) as cp_dir:
+                self._save_model(model, path=os.path.join(cp_dir, MODEL_KEY))
+            tune.report(**result_dict)
+        else:
+            with tempfile.TemporaryDirectory() as checkpoint_dir:
+                self._save_model(model, path=os.path.join(checkpoint_dir, MODEL_KEY))
+
+                if _use_storage_context():
+                    checkpoint = Checkpoint.from_directory(checkpoint_dir)
+                else:
+                    checkpoint = LegacyCheckpoint.from_directory(checkpoint_dir)
+
+                train.report(result_dict, checkpoint=checkpoint)
 
     def training_loop(self) -> None:
         config = self.train_kwargs.copy()
@@ -291,21 +308,16 @@ def training_loop(self) -> None:
         config.setdefault("callbacks", [])
 
         if not any(
-            isinstance(
-                cb, (self._tune_callback_report_cls, self._tune_callback_checkpoint_cls)
-            )
+            isinstance(cb, self._tune_callback_checkpoint_cls)
            for cb in config["callbacks"]
         ):
             # Only add our own callback if it hasn't been added before
             checkpoint_frequency = (
                 self.run_config.checkpoint_config.checkpoint_frequency
             )
-            if checkpoint_frequency > 0:
-                callback = self._tune_callback_checkpoint_cls(
-                    filename=MODEL_KEY, frequency=checkpoint_frequency
-                )
-            else:
-                callback = self._tune_callback_report_cls()
+            callback = self._tune_callback_checkpoint_cls(
+                filename=MODEL_KEY, frequency=checkpoint_frequency
+            )
 
             config["callbacks"] += [callback]
 
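The core of this change is the move from the legacy tune.checkpoint_dir() / tune.report() pair to the unified train.report() call, which attaches a Checkpoint to the reported metrics. Below is a minimal sketch of the new pattern as seen from a user-defined training function, mirroring the new else branch in _checkpoint_at_end; it assumes Ray >= 2.7, where ray.train.Checkpoint is public, and the function name, metrics, and model file are illustrative, not from this PR:

import os
import tempfile

from ray import train
from ray.train import Checkpoint


def train_fn(config):
    for step in range(config["num_steps"]):
        # ... one round of boosting / training ...
        with tempfile.TemporaryDirectory() as tmpdir:
            # Save the model into a fresh directory, then hand it to Ray as a
            # checkpoint alongside the metrics for this step. The directory is
            # uploaded/persisted by Ray before the TemporaryDirectory is cleaned up.
            with open(os.path.join(tmpdir, "model.txt"), "w") as f:
                f.write(f"model state at step {step}")
            train.report(
                {"step": step, "loss": 1.0 / (step + 1)},
                checkpoint=Checkpoint.from_directory(tmpdir),
            )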
3 changes: 1 addition & 2 deletions python/ray/train/lightgbm/lightgbm_trainer.py
@@ -15,7 +15,7 @@
 import lightgbm
 import lightgbm_ray
 import xgboost_ray
-from lightgbm_ray.tune import TuneReportCheckpointCallback, TuneReportCallback
+from lightgbm_ray.tune import TuneReportCheckpointCallback
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
@@ -95,7 +95,6 @@ class LightGBMTrainer(GBDTTrainer):
     # but it is explicitly set here for forward compatibility
     _dmatrix_cls: type = lightgbm_ray.RayDMatrix
     _ray_params_cls: type = lightgbm_ray.RayParams
-    _tune_callback_report_cls: type = TuneReportCallback
     _tune_callback_checkpoint_cls: type = TuneReportCheckpointCallback
     _default_ray_params: Dict[str, Any] = {
         "checkpoint_frequency": 1,
3 changes: 1 addition & 2 deletions python/ray/train/xgboost/xgboost_trainer.py
@@ -14,7 +14,7 @@
 
 import xgboost
 import xgboost_ray
-from xgboost_ray.tune import TuneReportCheckpointCallback, TuneReportCallback
+from xgboost_ray.tune import TuneReportCheckpointCallback
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
@@ -88,7 +88,6 @@ class XGBoostTrainer(GBDTTrainer):
 
     _dmatrix_cls: type = xgboost_ray.RayDMatrix
     _ray_params_cls: type = xgboost_ray.RayParams
-    _tune_callback_report_cls: type = TuneReportCallback
     _tune_callback_checkpoint_cls: type = TuneReportCheckpointCallback
     _default_ray_params: Dict[str, Any] = {
         "num_actors": 1,
8 changes: 3 additions & 5 deletions python/ray/tune/examples/lightgbm_example.py
@@ -5,10 +5,7 @@
 
 from ray import tune
 from ray.tune.schedulers import ASHAScheduler
-from ray.tune.integration.lightgbm import (
-    TuneReportCheckpointCallback,
-    TuneReportCallback,
-)
+from ray.tune.integration.lightgbm import TuneReportCheckpointCallback
 
 
 def train_breast_cancer(config: dict):
@@ -62,13 +59,14 @@ def train_breast_cancer_cv(config: dict):
         # with the cv_agg key. Both mean and standard deviation
         # are provided.
         callbacks=[
-            TuneReportCallback(
+            TuneReportCheckpointCallback(
                 {
                     "binary_error": "cv_agg-binary_error-mean",
                     "binary_logloss": "cv_agg-binary_logloss-mean",
                     "binary_error_stdv": "cv_agg-binary_error-stdv",
                     "binary_logloss_stdv": "cv_agg-binary_logloss-stdv",
                 },
+                frequency=0,
             )
         ],
     )
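The migration pattern in these examples: the removed TuneReportCallback becomes TuneReportCheckpointCallback with frequency=0, which forwards metrics to Tune without ever writing a checkpoint. A minimal self-contained sketch of that report-only usage, assuming it runs inside a Tune trainable like the one above (the params dict and metric mapping are illustrative, adapted from this example):

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from ray.tune.integration.lightgbm import TuneReportCheckpointCallback


def train_cv(config):
    data, labels = load_breast_cancer(return_X_y=True)
    train_set = lgb.Dataset(data, label=labels)
    lgb.cv(
        {"objective": "binary", "metric": ["binary_error", "binary_logloss"]},
        train_set,
        callbacks=[
            # frequency=0 disables checkpoint writing, so this only reports
            # metrics -- the behavior of the removed TuneReportCallback.
            TuneReportCheckpointCallback(
                {"binary_error": "cv_agg-binary_error-mean"},
                frequency=0,
            )
        ],
    )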
13 changes: 7 additions & 6 deletions python/ray/tune/examples/xgboost_example.py
@@ -9,10 +9,7 @@
 import ray
 from ray import tune
 from ray.tune.schedulers import ASHAScheduler
-from ray.tune.integration.xgboost import (
-    TuneReportCheckpointCallback,
-    TuneReportCallback,
-)
+from ray.tune.integration.xgboost import TuneReportCheckpointCallback
 
 
 def train_breast_cancer(config: dict):
@@ -33,7 +30,7 @@ def train_breast_cancer(config: dict):
         train_set,
         evals=[(test_set, "test")],
         verbose_eval=False,
-        callbacks=[TuneReportCheckpointCallback(filename="model.xgb")],
+        callbacks=[TuneReportCheckpointCallback(filename="model.xgb", frequency=1)],
     )
 
 
@@ -57,7 +54,11 @@ def average_cv_folds(results_dict: Dict[str, List[float]]) -> Dict[str, float]:
         verbose_eval=False,
         stratified=True,
         # Checkpointing is not supported for CV
-        callbacks=[TuneReportCallback(results_postprocessing_fn=average_cv_folds)],
+        callbacks=[
+            TuneReportCheckpointCallback(
+                results_postprocessing_fn=average_cv_folds, frequency=0
+            )
+        ],
     )
 
 
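With frequency=1, the callback saves a checkpoint after every boosting round, so the booster can be recovered from the tuning results afterwards. A minimal sketch of loading the best model, assuming a ResultGrid named results from a tuner.fit() call like the one in this example (the metric key "test-logloss" depends on the evals configuration and is illustrative):

import os
import xgboost

# Pick the trial with the lowest validation logloss.
best_result = results.get_best_result(metric="test-logloss", mode="min")

# The callback stored the booster under the filename passed to it
# (filename="model.xgb" above).
with best_result.checkpoint.as_directory() as checkpoint_dir:
    booster = xgboost.Booster()
    booster.load_model(os.path.join(checkpoint_dir, "model.xgb"))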