# Review commentary. Everything from the first `diff --git` header onward is
# the patch text itself and is reproduced unchanged.
#
# NOTE(review): the patch's line breaks and indentation have been collapsed
# onto three physical lines; the hunk line structure has to be restored
# before the patch can be applied.
#
# smartsim/_core/control/job.py
#   * `__init__` gains a private `_is_complete` flag initialised to False.
#   * An `is_complete` property returning that flag and a `set_complete()`
#     method that sets it to True are added just ahead of `class Job`.
#     The enclosing class header is outside the hunks -- presumably
#     JobEntity, since `tracked_jobs` below is annotated as
#     `t.Iterable[JobEntity]`; confirm.
#
# smartsim/_core/entrypoints/telemetrymonitor.py
#   * Adds `import redis.exceptions as redisex`.
#   * Renames `CollectorStopHandler` to `TaskStatusHandler` in the class
#     statement, the `_stoppers` annotation and `_create_stop_listener`.
#   * `_notify` now calls `col.entity.set_complete()` right after
#     `col.disable()` for every collector in `working_set`.
#   * Adds `async def _check_db(self) -> bool`: when `self._client` is set it
#     returns the awaited `self._client.ping()`; on
#     `redisex.ConnectionError` it logs at info level and returns False;
#     with no client it returns False.
#   * Three `collect` coroutines now return early, inside their existing
#     `try:`, when `_check_db()` is falsy.
#   * Adds a `tracked_jobs` property returning `self._tracked_jobs.values()`.
#   * `can_shutdown` no longer reads `job_manager.db_jobs`. It counts
#     `job_manager.jobs` plus the entries of `action_handler.tracked_jobs`
#     for which `is_db and not is_complete`, and returns True only when both
#     counts are zero. The debug message is now emitted on every call; the
#     removed code logged only when a collection was non-empty.
#
# NOTE(review): `_check_db` handles only `redisex.ConnectionError`. Whether
# other failures raised by `ping()` are handled depends on the `except`
# clauses of the calling `collect` methods, which these hunks do not show --
# confirm.
# NOTE(review): in the new `can_shutdown`, tracked entities that are not
# databases never hold the monitor open -- confirm this is intended.
diff --git a/smartsim/_core/control/job.py b/smartsim/_core/control/job.py index 924a043de..76ab43d6f 100644 --- a/smartsim/_core/control/job.py +++ b/smartsim/_core/control/job.py @@ -54,6 +54,7 @@ def __init__(self) -> None: self.telemetry_on: bool = False self.collectors: t.Dict[str, str] = {} self.config: t.Dict[str, str] = {} + self._is_complete: bool = False @property def is_db(self) -> bool: @@ -67,6 +68,13 @@ def is_managed(self) -> bool: def key(self) -> _JobKey: return _JobKey(self.step_id, self.task_id) + @property + def is_complete(self) -> bool: + return self._is_complete + + def set_complete(self) -> None: + self._is_complete = True + class Job: """Keep track of various information for the controller. diff --git a/smartsim/_core/entrypoints/telemetrymonitor.py b/smartsim/_core/entrypoints/telemetrymonitor.py index 96cace6c1..f4e30e06f 100644 --- a/smartsim/_core/entrypoints/telemetrymonitor.py +++ b/smartsim/_core/entrypoints/telemetrymonitor.py @@ -43,6 +43,7 @@ from types import FrameType import redis.asyncio as redis +import redis.exceptions as redisex from anyio import open_file, sleep from watchdog.events import ( FileCreatedEvent, @@ -178,7 +179,7 @@ async def shutdown(self) -> None: """Execute any cleanup of resources for the collector""" -class CollectorStopHandler(PatternMatchingEventHandler): +class TaskStatusHandler(PatternMatchingEventHandler): """A file listener that will notify a set of collectors when an unmanaged entity has completed""" @@ -224,6 +225,7 @@ def _notify(self, event_src: str) -> None: for col in working_set: logger.debug(f"Disabling {col.entity.name}::{type(col).__name__}") col.disable() + col.entity.set_complete() @dataclasses.dataclass @@ -303,6 +305,18 @@ async def shutdown(self) -> None: except Exception as ex: logger.error("An error occurred during DbCollector shutdown", exc_info=ex) + async def _check_db(self) -> bool: + """Check if a database is reachable. 
+ + :returns: True if connection succeeds, False otherwise.""" + try: + if self._client: + return await self._client.ping() + except redisex.ConnectionError: + logger.info(f"Cannot ping db {self._address}") + + return False + class DbMemoryCollector(DbCollector): """A collector that collects memory usage information from @@ -334,6 +348,9 @@ async def collect(self) -> None: self._value = {} try: + if not await self._check_db(): + return + db_info = await self._client.info("memory") for key in self._columns: self._value[key] = db_info[key] @@ -370,6 +387,9 @@ async def collect(self) -> None: now_ts = self.timestamp() # ensure all results have the same timestamp try: + if not await self._check_db(): + return + clients = await self._client.client_list() self._value = [{"addr": item["addr"], "id": item["id"]} for item in clients] @@ -408,6 +428,9 @@ async def collect(self) -> None: return try: + if not await self._check_db(): + return + client_list = await self._client.client_list() now_ts = self.timestamp() # ensure all results have the same timestamp @@ -437,7 +460,7 @@ def __init__(self, timeout_ms: int = 1000) -> None: self._timeout_ms = timeout_ms self._tasks: t.List[asyncio.Task[None]] = [] self._stoppers: t.Dict[ - str, t.List[CollectorStopHandler] + str, t.List[TaskStatusHandler] ] = collections.defaultdict(lambda: []) def clear(self) -> None: @@ -457,7 +480,7 @@ def _create_stop_listener(self, collector: Collector) -> None: stopper.add(collector) return - stopper = CollectorStopHandler(collector, patterns=["stop.json"]) + stopper = TaskStatusHandler(collector, patterns=["stop.json"]) observer = Observer() observer.schedule(stopper, collector.entity.status_dir) # type: ignore observer.start() # type: ignore @@ -835,6 +858,10 @@ def __init__( def timeout_ms(self) -> int: return self._timeout_ms + @property + def tracked_jobs(self) -> t.Iterable[JobEntity]: + return self._tracked_jobs.values() + def init_launcher(self, launcher: str) -> Launcher: """Initialize the 
controller with a specific type of launcher. SmartSim currently supports slurm, pbs(pro), lsf, @@ -1012,16 +1039,15 @@ async def shutdown(self) -> None: def can_shutdown(action_handler: ManifestEventHandler) -> bool: - jobs = action_handler.job_manager.jobs - db_jobs = action_handler.job_manager.db_jobs + jobs = action_handler.job_manager.jobs # managed jobs from job manager + all_jobs = action_handler.tracked_jobs # unmanaged jobs tracked locally + db_jobs = list(filter(lambda j: j.is_db and not j.is_complete, all_jobs)) - if has_jobs := bool(jobs): - logger.debug(f"telemetry monitor is monitoring {len(jobs)} jobs") - if has_dbs := bool(db_jobs): - logger.debug(f"telemetry monitor is monitoring {len(db_jobs)} dbs") + n_jobs, n_dbs = len(jobs), len(db_jobs) + shutdown_ok = n_jobs + n_dbs == 0 - has_running_jobs = has_jobs or has_dbs - return not has_running_jobs + logger.debug(f"{n_jobs} active job(s), {n_dbs} active db(s)") + return shutdown_ok async def event_loop(