diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index 56edc5fc7062..2d83fb814408 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -1512,7 +1512,7 @@ def _print_result(self, result: Dict): return print_result_str -def _detect_reporter(**kwargs) -> TuneReporterBase: +def _detect_reporter(_trainer_api: bool = False, **kwargs) -> TuneReporterBase: """Detect progress reporter class. Will return a :class:`JupyterNotebookReporter` if a IPython/Jupyter-like @@ -1520,7 +1520,7 @@ def _detect_reporter(**kwargs) -> TuneReporterBase: Keyword arguments are passed on to the reporter class. """ - if IS_NOTEBOOK: + if IS_NOTEBOOK and not _trainer_api: kwargs.setdefault("overwrite", not has_verbosity(Verbosity.V2_TRIAL_NORM)) progress_reporter = JupyterNotebookReporter(**kwargs) else: diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py index 2851dc007c14..2bc68ff49a59 100644 --- a/python/ray/tune/tests/test_progress_reporter.py +++ b/python/ray/tune/tests/test_progress_reporter.py @@ -801,6 +801,9 @@ def testReporterDetection(self): reporter = _detect_reporter() self.assertFalse(isinstance(reporter, CLIReporter)) self.assertTrue(isinstance(reporter, JupyterNotebookReporter)) + trainer_reporter = _detect_reporter(_trainer_api=True) + self.assertFalse(isinstance(trainer_reporter, JupyterNotebookReporter)) + self.assertTrue(isinstance(trainer_reporter, CLIReporter)) def testProgressReporterAPI(self): class CustomReporter(ProgressReporter): diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index bcbfa17eb0b9..299c2c774fe1 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -903,7 +903,10 @@ def run( progress_reporter = None if air_verbosity is None: - progress_reporter = progress_reporter or _detect_reporter() + is_trainer = _entrypoint == AirEntrypoint.TRAINER + progress_reporter = progress_reporter or _detect_reporter( + _trainer_api=is_trainer + ) if resume is not None: resume_config = resume_config or _build_resume_config_from_legacy_config(resume)