Skip to content

Commit

Permalink
solve early shutdown bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ankona committed Feb 22, 2024
1 parent 4e71660 commit 1572d11
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 11 deletions.
8 changes: 8 additions & 0 deletions smartsim/_core/control/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self) -> None:
self.telemetry_on: bool = False
self.collectors: t.Dict[str, str] = {}
self.config: t.Dict[str, str] = {}
self._is_complete: bool = False

@property
def is_db(self) -> bool:
Expand All @@ -67,6 +68,13 @@ def is_managed(self) -> bool:
def key(self) -> _JobKey:
return _JobKey(self.step_id, self.task_id)

@property
def is_complete(self) -> bool:
return self._is_complete

def set_complete(self) -> None:
self._is_complete = True


class Job:
"""Keep track of various information for the controller.
Expand Down
48 changes: 37 additions & 11 deletions smartsim/_core/entrypoints/telemetrymonitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from types import FrameType

import redis.asyncio as redis
import redis.exceptions as redisex
from anyio import open_file, sleep
from watchdog.events import (
FileCreatedEvent,
Expand Down Expand Up @@ -178,7 +179,7 @@ async def shutdown(self) -> None:
"""Execute any cleanup of resources for the collector"""


class CollectorStopHandler(PatternMatchingEventHandler):
class TaskStatusHandler(PatternMatchingEventHandler):
"""A file listener that will notify a set of collectors when an
unmanaged entity has completed"""

Expand Down Expand Up @@ -224,6 +225,7 @@ def _notify(self, event_src: str) -> None:
for col in working_set:
logger.debug(f"Disabling {col.entity.name}::{type(col).__name__}")
col.disable()
col.entity.set_complete()


@dataclasses.dataclass
Expand Down Expand Up @@ -303,6 +305,18 @@ async def shutdown(self) -> None:
except Exception as ex:
logger.error("An error occurred during DbCollector shutdown", exc_info=ex)

async def _check_db(self) -> bool:
"""Check if a database is reachable.
:returns: True if connection succeeds, False otherwise."""
try:
if self._client:
return await self._client.ping()
except redisex.ConnectionError:
logger.info(f"Cannot ping db {self._address}")

return False


class DbMemoryCollector(DbCollector):
"""A collector that collects memory usage information from
Expand Down Expand Up @@ -334,6 +348,9 @@ async def collect(self) -> None:
self._value = {}

try:
if not await self._check_db():
return

db_info = await self._client.info("memory")
for key in self._columns:
self._value[key] = db_info[key]
Expand Down Expand Up @@ -370,6 +387,9 @@ async def collect(self) -> None:
now_ts = self.timestamp() # ensure all results have the same timestamp

try:
if not await self._check_db():
return

clients = await self._client.client_list()

self._value = [{"addr": item["addr"], "id": item["id"]} for item in clients]
Expand Down Expand Up @@ -408,6 +428,9 @@ async def collect(self) -> None:
return

try:
if not await self._check_db():
return

client_list = await self._client.client_list()

now_ts = self.timestamp() # ensure all results have the same timestamp
Expand Down Expand Up @@ -437,7 +460,7 @@ def __init__(self, timeout_ms: int = 1000) -> None:
self._timeout_ms = timeout_ms
self._tasks: t.List[asyncio.Task[None]] = []
self._stoppers: t.Dict[
str, t.List[CollectorStopHandler]
str, t.List[TaskStatusHandler]
] = collections.defaultdict(lambda: [])

def clear(self) -> None:
Expand All @@ -457,7 +480,7 @@ def _create_stop_listener(self, collector: Collector) -> None:
stopper.add(collector)
return

stopper = CollectorStopHandler(collector, patterns=["stop.json"])
stopper = TaskStatusHandler(collector, patterns=["stop.json"])
observer = Observer()
observer.schedule(stopper, collector.entity.status_dir) # type: ignore
observer.start() # type: ignore
Expand Down Expand Up @@ -835,6 +858,10 @@ def __init__(
def timeout_ms(self) -> int:
return self._timeout_ms

@property
def tracked_jobs(self) -> t.Iterable[JobEntity]:
return self._tracked_jobs.values()

def init_launcher(self, launcher: str) -> Launcher:
"""Initialize the controller with a specific type of launcher.
SmartSim currently supports slurm, pbs(pro), lsf,
Expand Down Expand Up @@ -1012,16 +1039,15 @@ async def shutdown(self) -> None:


def can_shutdown(action_handler: ManifestEventHandler) -> bool:
jobs = action_handler.job_manager.jobs
db_jobs = action_handler.job_manager.db_jobs
jobs = action_handler.job_manager.jobs # managed jobs from job manager
all_jobs = action_handler.tracked_jobs # unmanaged jobs tracked locally
db_jobs = list(filter(lambda j: j.is_db and not j.is_complete, all_jobs))

if has_jobs := bool(jobs):
logger.debug(f"telemetry monitor is monitoring {len(jobs)} jobs")
if has_dbs := bool(db_jobs):
logger.debug(f"telemetry monitor is monitoring {len(db_jobs)} dbs")
n_jobs, n_dbs = len(jobs), len(db_jobs)
shutdown_ok = n_jobs + n_dbs == 0

has_running_jobs = has_jobs or has_dbs
return not has_running_jobs
logger.debug(f"{n_jobs} active job(s), {n_dbs} active db(s)")
return shutdown_ok


async def event_loop(
Expand Down

0 comments on commit 1572d11

Please sign in to comment.