Skip to content

Commit 57dcd30

Browse files
andrewztanrichardliaw
authored andcommitted
[tune] Trial reporter fix (#3951)
Fixes #3949.
1 parent 3a7fb18 commit 57dcd30

File tree

4 files changed

+35
-7
lines changed

4 files changed

+35
-7
lines changed

python/ray/tune/function_runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ray.tune import TuneError
1010
from ray.tune.trainable import Trainable
11-
from ray.tune.result import TIMESTEPS_TOTAL
11+
from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
1212

1313
logger = logging.getLogger(__name__)
1414

@@ -28,6 +28,7 @@ def __init__(self):
2828
self._lock = threading.Lock()
2929
self._error = None
3030
self._done = False
31+
self._iteration = 0
3132

3233
def __call__(self, **kwargs):
3334
"""Report updated training status.
@@ -44,6 +45,7 @@ def __call__(self, **kwargs):
4445

4546
with self._lock:
4647
self._latest_result = self._last_result = kwargs.copy()
48+
self._iteration += 1
4749

4850
def _get_and_clear_status(self):
4951
if self._error:
@@ -55,10 +57,13 @@ def _get_and_clear_status(self):
5557
"last result. To avoid this, include done=True "
5658
"upon the last reporter call.")
5759
self._last_result.update(done=True)
60+
self._last_result.setdefault(TRAINING_ITERATION, self._iteration)
5861
return self._last_result
5962
with self._lock:
6063
res = self._latest_result
6164
self._latest_result = None
65+
if res:
66+
res.setdefault(TRAINING_ITERATION, self._iteration)
6267
return res
6368

6469
def _stop(self):

python/ray/tune/suggest/bayesopt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def on_trial_complete(self,
9696
self.optimizer.register(
9797
params=self._live_trial_mapping[trial_id],
9898
target=result[self._reward_attr])
99-
del self._live_trial_mapping[trial_id]
99+
100+
del self._live_trial_mapping[trial_id]
100101

101102
def _num_live_trials(self):
102103
return len(self._live_trial_mapping)

python/ray/tune/test/trial_runner_test.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
2020
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
2121
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
22-
EPISODES_TOTAL)
22+
EPISODES_TOTAL, TRAINING_ITERATION)
2323
from ray.tune.logger import Logger
2424
from ray.tune.util import pin_in_object_store, get_pinned_object
2525
from ray.tune.experiment import Experiment
@@ -560,6 +560,28 @@ def _restore(self, state):
560560
self.assertEqual(trial.status, Trial.TERMINATED)
561561
self.assertTrue(trial.has_checkpoint())
562562

563+
def testIterationCounter(self):
564+
def train(config, reporter):
565+
for i in range(100):
566+
reporter(itr=i, done=i == 99)
567+
568+
register_trainable("exp", train)
569+
config = {
570+
"my_exp": {
571+
"run": "exp",
572+
"config": {
573+
"iterations": 100,
574+
},
575+
"stop": {
576+
"timesteps_total": 100
577+
},
578+
}
579+
}
580+
[trial] = run_experiments(config)
581+
self.assertEqual(trial.status, Trial.TERMINATED)
582+
self.assertEqual(trial.last_result[TRAINING_ITERATION], 100)
583+
self.assertEqual(trial.last_result["itr"], 99)
584+
563585

564586
class RunExperimentTest(unittest.TestCase):
565587
def setUp(self):

python/ray/tune/trainable.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
import ray
2020
from ray.tune.logger import UnifiedLogger
21-
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
22-
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
23-
EPISODES_THIS_ITER, EPISODES_TOTAL)
21+
from ray.tune.result import (
22+
DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, TIMESTEPS_THIS_ITER, DONE,
23+
TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, TRAINING_ITERATION)
2424
from ray.tune.trial import Resources
2525

2626
logger = logging.getLogger(__name__)
@@ -181,6 +181,7 @@ def train(self):
181181
# self._timesteps_total should not override user-provided total
182182
result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total)
183183
result.setdefault(EPISODES_TOTAL, self._episodes_total)
184+
result.setdefault(TRAINING_ITERATION, self._iteration)
184185

185186
# Provides auto-filled neg_mean_loss for avoiding regressions
186187
if result.get("mean_loss"):
@@ -191,7 +192,6 @@ def train(self):
191192
experiment_id=self._experiment_id,
192193
date=now.strftime("%Y-%m-%d_%H-%M-%S"),
193194
timestamp=int(time.mktime(now.timetuple())),
194-
training_iteration=self._iteration,
195195
time_this_iter_s=time_this_iter,
196196
time_total_s=self._time_total,
197197
pid=os.getpid(),

0 commit comments

Comments
 (0)