from unittest.mock import patch

import pytest
-import shutil
import unittest
from typing import Optional

@@ -23,7 +22,6 @@
from ray.train.trainer import BaseTrainer
from ray.train.xgboost import XGBoostTrainer
from ray.tune import Callback, CLIReporter
-from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.tune_config import TuneConfig
from ray.tune.tuner import Tuner

@@ -106,6 +104,12 @@ def gen_dataset_func_eager():
class TunerTest(unittest.TestCase):
    """The e2e test for hparam tuning using Tuner API."""

+    @pytest.fixture(autouse=True)
+    def local_dir(self, tmp_path, monkeypatch):
+        monkeypatch.setenv("RAY_AIR_LOCAL_CACHE_DIR", str(tmp_path / "ray_results"))
+        self.local_dir = str(tmp_path / "ray_results")
+        yield self.local_dir
+
    def setUp(self):
        ray.init()

@@ -114,9 +118,6 @@ def tearDown(self):

    def test_tuner_with_xgboost_trainer(self):
        """Test a successful run."""
-        shutil.rmtree(
-            os.path.join(DEFAULT_RESULTS_DIR, "test_tuner"), ignore_errors=True
-        )
        trainer = XGBoostTrainer(
            label_column="target",
            params={},
@@ -156,10 +157,6 @@ def test_tuner_with_xgboost_trainer(self):
    def test_tuner_with_xgboost_trainer_driver_fail_and_resume(self):
        # So that we have some global checkpointing happening.
        os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "1"
-        shutil.rmtree(
-            os.path.join(DEFAULT_RESULTS_DIR, "test_tuner_driver_fail"),
-            ignore_errors=True,
-        )
        trainer = XGBoostTrainer(
            label_column="target",
            params={},
@@ -211,18 +208,16 @@ def on_step_end(self, iteration, trials, **kwargs):
        tuner.fit()

        # Test resume
-        restore_path = os.path.join(DEFAULT_RESULTS_DIR, "test_tuner_driver_fail")
-        tuner = Tuner.restore(restore_path, trainable=trainer)
+        restore_path = os.path.join(self.local_dir, "test_tuner_driver_fail")
+        tuner = Tuner.restore(restore_path, trainable=trainer, param_space=param_space)
        # A hack before we figure out RunConfig semantics across resumes.
        tuner._local_tuner._run_config.callbacks = None
        results = tuner.fit()
        assert len(results) == 4
+        assert not results.errors

    def test_tuner_with_torch_trainer(self):
        """Test a successful run using torch trainer."""
-        shutil.rmtree(
-            os.path.join(DEFAULT_RESULTS_DIR, "test_tuner_torch"), ignore_errors=True
-        )
        # The following two should be tunable.
        config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": 10}
        scaling_config = ScalingConfig(num_workers=1, use_gpu=False)
@@ -387,6 +382,8 @@ def test_nonserializable_trainable():
        Tuner(lambda config: print(lock))


+# TODO(justinvyu): [chdir_to_trial_dir]
+@pytest.mark.skip("chdir_to_trial_dir is not implemented yet.")
@pytest.mark.parametrize("runtime_env", [{}, {"working_dir": "."}])
def test_tuner_no_chdir_to_trial_dir(shutdown_only, chdir_tmpdir, runtime_env):
    """Tests that setting `chdir_to_trial_dir=False` in `TuneConfig` allows for