Skip to content

Commit

Permalink
tweak sink/collector api, log addresses in errors, add MockSink to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ankona committed Feb 22, 2024
1 parent 655b0c2 commit 640d8b5
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 39 deletions.
52 changes: 36 additions & 16 deletions smartsim/_core/entrypoints/telemetrymonitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import argparse
import asyncio
import collections
import dataclasses
import datetime
import itertools
import json
Expand Down Expand Up @@ -79,13 +80,15 @@

class Sink(abc.ABC):
    """Abstract interface for telemetry output destinations.

    Concrete subclasses decide where collected metric values go
    (e.g. a file, the application log, or an in-memory capture for tests).
    """

    @abc.abstractmethod
    def save(self, **kwargs: t.Any) -> None:
        """Persist the given key/value metric data to the sink.

        :param kwargs: arbitrary metric names mapped to their values
        """
        ...


class FileSink(Sink):
"""Telemetry sink that writes to a file"""

def __init__(self, entity: JobEntity, sub: str) -> None:
# todo: consider renaming sub (it's the sub-path under entity.status_dir)
# todo: consider specifying sub & file name separately?
Expand All @@ -101,6 +104,7 @@ def save(self, **kwargs: t.Any) -> None:

class LogSink(Sink):
    """Telemetry sink that emits metric values to the module logger.

    Intended for console inspection during testing; all values of a
    single save call are joined into one comma-separated INFO record.
    """

    def save(self, **kwargs: t.Any) -> None:
        """Write all supplied metric values as one comma-separated log line."""
        rendered = (str(value) for value in kwargs.values())
        logger.info(",".join(rendered))
Expand All @@ -109,18 +113,23 @@ def save(self, **kwargs: t.Any) -> None:
class Collector(abc.ABC):
"""Base class for metrics collectors"""

def __init__(self, entity: JobEntity) -> None:
def __init__(self, entity: JobEntity, sink: Sink) -> None:
"""Initialize the collector
:param entity: The entity to collect metrics on
:type entity: JobEntity"""
self._entity = entity
self._sink = sink
self._value: t.Any = None

@property
def owner(self) -> str:
return self._entity.name

@property
def sink(self) -> Sink:
return self._sink

@abc.abstractmethod
async def prepare(self) -> None:
"""Initialization logic for a collector"""
Expand All @@ -134,32 +143,45 @@ def timestamp() -> int:
return int(datetime.datetime.timestamp(datetime.datetime.now()))


@dataclasses.dataclass
class _Address:
"""Helper class to hold and pretty-print connection details"""

host: str
port: int

def __str__(self) -> str:
return f"{self.host}:{self.port}"


class DbCollector(Collector):
"""A base class for collectors that retrieve statistics from an orchestrator"""

def __init__(self, entity: JobEntity, sink: Sink) -> None:
"""Initialize the collector"""
super().__init__(entity)
super().__init__(entity, sink)
self._client: t.Optional[redis.Redis[bytes]] = None
self._sink = sink
self._address = _Address(
self._entity.meta.get("host", "127.0.0.1"),
int(self._entity.meta.get("port", 6379)),
)

async def _configure_client(self) -> None:
"""Configure and connect to the target database"""
try:
db_host = self._entity.meta.get("host", "localhost")
db_port = self._entity.meta.get("port", "6379")

if not self._client:
self._client = redis.Redis(host=db_host, port=int(db_port))
self._client = redis.Redis(
host=self._address.host, port=self._address.port
)

except Exception as e:
logger.exception(e)
raise SmartSimError(
"Collector failed to communicate with metric producer"
) from e
msg = f"DbCollector failed to communicate with {self._address}"
raise SmartSimError(msg) from e

if not self._client: # or not self._client.is_connected:
raise SmartSimError("Collector failed to connect to metric producer")
msg = f"DbCollector failed to connect to {self._address}"
raise SmartSimError(msg)

async def prepare(self) -> None:
"""Initialization logic for a DB collector"""
Expand All @@ -176,7 +198,7 @@ class DbMemoryCollector(DbCollector):
async def collect(self) -> None:
await self.prepare()
if not self._client:
logger.warning("DbMemoryCollector is not connected and cannot collect")
logger.warning("DbMemoryCollector cannot collect")
return

db_info = await self._client.info()
Expand Down Expand Up @@ -257,9 +279,8 @@ async def collect(self) -> None:
logger.debug("Executing all telemetry collectors")

if collectors := self.all_collectors:
results = await asyncio.wait(
list(collector.collect() for collector in collectors)
)
tasks = [collector.collect() for collector in collectors]
results = await asyncio.gather(*tasks)
print(f"collector.collect() results:\n{results}")

@classmethod
Expand Down Expand Up @@ -707,7 +728,6 @@ async def on_timestep(self, timestamp: int) -> None:
names = {entity.name: entity for entity in entity_map.values()}

# trigger all metric collection for the timestep
# asyncio.run(self._collector.collect())
await self._collector.collect()

if names:
Expand Down
56 changes: 33 additions & 23 deletions tests/test_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,23 @@
DbConnectionCollector,
DbMemoryCollector,
JobEntity,
LogSink,
redis,
Sink,
)
from smartsim.error import SmartSimError

# The tests in this file belong to the slow_tests group
pytestmark = pytest.mark.group_a


class MockSink(Sink):
    """Telemetry sink that captures the most recently saved values in memory.

    Lets tests assert that a collector forwarded its metrics to its sink
    without writing to files or the logger.
    """

    # most recently saved kwargs; None until save() is first called, so
    # reading .args before any save no longer raises AttributeError
    args: t.Optional[t.Dict[str, t.Any]] = None

    def save(self, **kwargs: t.Any) -> None:
        """Record the supplied values for later inspection by a test."""
        print(f"MockSink received args: {kwargs}")
        self.args = kwargs


@pytest.fixture
def mock_redis():
def _mock_redis(
Expand Down Expand Up @@ -83,14 +91,14 @@ async def test_dbmemcollector_prepare(mock_entity):
"""Ensure that collector preparation succeeds when expected"""
entity = mock_entity()

collector = DbMemoryCollector(entity, LogSink())
collector = DbMemoryCollector(entity, MockSink())
await collector.prepare()
assert collector._client


@pytest.mark.asyncio
async def test_dbmemcollector_prepare_fail(
mock_entity, mock_redis, monkeypatch: pytest.MonkeyPatch
mock_entity, monkeypatch: pytest.MonkeyPatch
):
"""Ensure that collector preparation reports a failure to connect"""
entity = mock_entity()
Expand All @@ -100,13 +108,13 @@ async def test_dbmemcollector_prepare_fail(
ctx.setattr(redis, "Redis", lambda host, port: None)

with pytest.raises(SmartSimError) as ex:
collector = DbMemoryCollector(entity, LogSink())
collector = DbMemoryCollector(entity, MockSink())
await collector.prepare()

assert not collector._client

err_content = ",".join(ex.value.args)
assert "failed to connect" in err_content
assert "connect" in err_content


@pytest.mark.asyncio
Expand All @@ -121,7 +129,7 @@ def raiser():
# mock raising exception on connect attempts to test err handling
raise redis.ConnectionError("mock connection failure")

collector = DbMemoryCollector(entity, LogSink())
collector = DbMemoryCollector(entity, MockSink())
with monkeypatch.context() as ctx:
ctx.setattr(redis, "Redis", raiser)
with pytest.raises(SmartSimError) as ex:
Expand All @@ -130,7 +138,7 @@ def raiser():
assert not collector._client

err_content = ",".join(ex.value.args)
assert "failed to communicate" in err_content
assert "communicate" in err_content


@pytest.mark.asyncio
Expand All @@ -140,7 +148,7 @@ async def test_dbmemcollector_collect(
"""Ensure that a valid response is returned as expected"""
entity = mock_entity()

collector = DbMemoryCollector(entity, LogSink())
collector = DbMemoryCollector(entity, MockSink())
with monkeypatch.context() as ctx:
m1, m2, m3 = 12345, 23456, 34567
mock_data = {
Expand All @@ -163,7 +171,7 @@ async def test_dbmemcollector_integration(mock_entity, local_db):
output data matches expectations and proper db client API usage"""
entity = mock_entity(port=local_db.ports[0])

collector = DbMemoryCollector(entity, LogSink())
collector = DbMemoryCollector(entity, MockSink())

await collector.prepare()
await collector.collect()
Expand All @@ -182,7 +190,7 @@ async def test_dbconncollector_collect(
"""Ensure that a valid response is returned as expected"""
entity = mock_entity()

collector = DbConnectionCollector(entity, LogSink())
collector = DbConnectionCollector(entity, MockSink())
with monkeypatch.context() as ctx:
a1, a2 = "127.0.0.1:1234", "127.0.0.1:2345"
mock_data = [
Expand All @@ -209,7 +217,7 @@ async def test_dbconncollector_integration(mock_entity, local_db):
output data matches expectations and proper db client API usage"""
entity = mock_entity(port=local_db.ports[0])

collector = DbConnectionCollector(entity, LogSink())
collector = DbConnectionCollector(entity, MockSink())

await collector.prepare()
await collector.collect()
Expand All @@ -223,8 +231,8 @@ def test_collector_manager_add(mock_entity):
"""Ensure that collector manager add & clear work as expected"""
entity1 = mock_entity()

con_col = DbConnectionCollector(entity1, LogSink())
mem_col = DbMemoryCollector(entity1, LogSink())
con_col = DbConnectionCollector(entity1, MockSink())
mem_col = DbMemoryCollector(entity1, MockSink())

manager = CollectorManager()

Expand All @@ -245,7 +253,7 @@ def test_collector_manager_add(mock_entity):

# create a collector for another entity
entity2 = mock_entity()
con_col2 = DbConnectionCollector(entity2, LogSink())
con_col2 = DbConnectionCollector(entity2, MockSink())

# ensure collectors w/same type for new entities are not treated as dupes
manager.add(con_col2)
Expand All @@ -267,8 +275,8 @@ def test_collector_manager_add_multi(mock_entity):
"""Ensure that collector manager multi-add works as expected"""
entity = mock_entity()

con_col = DbConnectionCollector(entity, LogSink())
mem_col = DbMemoryCollector(entity, LogSink())
con_col = DbConnectionCollector(entity, MockSink())
mem_col = DbMemoryCollector(entity, MockSink())
manager = CollectorManager()

# add multiple items at once
Expand All @@ -277,8 +285,8 @@ def test_collector_manager_add_multi(mock_entity):
assert len(manager.all_collectors) == 2

# ensure multi-add does not produce dupes
con_col2 = DbConnectionCollector(entity, LogSink())
mem_col2 = DbMemoryCollector(entity, LogSink())
con_col2 = DbConnectionCollector(entity, MockSink())
mem_col2 = DbMemoryCollector(entity, MockSink())

manager.add_all([con_col2, mem_col2])
assert len(manager.all_collectors) == 2
Expand All @@ -290,17 +298,19 @@ async def test_collector_manager_collect(mock_entity, local_db):
entity1 = mock_entity(port=local_db.ports[0])
entity2 = mock_entity(port=local_db.ports[0])

con_col1 = DbConnectionCollector(entity1, LogSink())
mem_col1 = DbMemoryCollector(entity1, LogSink())
mem_col2 = DbMemoryCollector(entity2, LogSink())
# todo: consider a MockSink so i don't have to save the last value in the collector
s1, s2, s3 = MockSink(), MockSink(), MockSink()
con_col1 = DbConnectionCollector(entity1, s1)
mem_col1 = DbMemoryCollector(entity1, s2)
mem_col2 = DbMemoryCollector(entity2, s3)

manager = CollectorManager()
manager.add_all([con_col1, mem_col1, mem_col2])

# Execute collection
await manager.collect()

# verify each collector retrieved some metric
# verify each collector retrieved some metric & sent it to the sink
for collector in manager.all_collectors:
value = collector._value
value = t.cast(MockSink, collector.sink).args
assert value is not None and value

0 comments on commit 640d8b5

Please sign in to comment.