diff --git a/Makefile b/Makefile index 5f11ea8..50115a6 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,9 @@ -.PHONY: lint deps test +.PHONY: lint deps test test-verbose all: lint test lint: deps - pipenv run flake8 --ignore=E501 . + pipenv run flake8 --ignore=E501,W503 . pipenv run mypy --strict . pipenv run black --check . @@ -11,4 +11,7 @@ deps: pipenv sync --dev test: deps - pipenv run pytest tests.py + pipenv run pytest tests.py -sv + +test-verbose: deps + pipenv run pytest tests.py -v -o log_cli=true --capture=fd --show-capture=stderr --log-level=DEBUG diff --git a/ssv_cluster_exporter.py b/ssv_cluster_exporter.py index 1fa5ad4..c23fe3b 100644 --- a/ssv_cluster_exporter.py +++ b/ssv_cluster_exporter.py @@ -2,6 +2,7 @@ import asyncio import copy import enum +import functools import json import logging import pathlib @@ -12,7 +13,7 @@ import furl # type: ignore[import-untyped] from prometheus_async import aio from prometheus_client import Gauge -from pydantic import AfterValidator, BaseModel, computed_field +from pydantic import AfterValidator, BaseModel, ConfigDict, computed_field from pydantic_settings import BaseSettings from web3 import Web3 from web3.contract import AsyncContract @@ -42,11 +43,29 @@ def ssv_network_views_contract(self) -> str: raise RuntimeError("Can not derive SSV network views address for network") -# ################### +# #################### # Aiohttp & web3 apps -def get_application() -> web.Application: +async def start_exporter_app(app: web.Application) -> None: + exporter: SSVClusterExporter = app[exporter_app_key] + # Reuse client session for web3 and ssv api + await exporter.ethereum_rpc.provider.cache_async_session(exporter.session) # type: ignore[attr-defined] + # Acquire data once to verify its working + await exporter.tick() + # Spawn long-running process + exporter.start() + + +async def stop_exporter_app(app: web.Application) -> None: + exporter: SSVClusterExporter = app[exporter_app_key] + await exporter.stop() + + +def get_application(exporter: "SSVClusterExporter") -> web.Application: app = web.Application() + app[exporter_app_key] = exporter app.router.add_get("/metrics", aio.web.server_stats) + app.on_startup.append(start_exporter_app) + app.on_shutdown.append(stop_exporter_app) return app @@ -209,10 +228,9 @@ class SSVNetworkProperties(BaseModel): class SSVNetworkContract(BaseModel): """A facade for web3 contract data retrieval for network wide values.""" - network_views: AsyncContract + model_config = ConfigDict(arbitrary_types_allowed=True) - class Config: - arbitrary_types_allowed = True + network_views: AsyncContract async def fetch_network_fee(self) -> int: return int(await self.network_views.functions.getNetworkFee().call()) @@ -243,11 +261,11 @@ async def fetch_all(self) -> SSVNetworkProperties: class SSVClusterContract(BaseModel): """A facade for web3 contract data retrieval for clusters.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + network_views: AsyncContract clusters: set[SSVCluster] - - class Config: - arbitrary_types_allowed = True + loop: asyncio.AbstractEventLoop def contract_call_args(self, cluster: SSVCluster) -> SSVNetworkViewsCallArgs: return ( @@ -277,13 +295,13 @@ async def get_cluster_burn_rate(self, cluster: SSVCluster) -> None: async def fetch_balances(self) -> None: futs = [] for cluster in self.clusters: - futs.append(asyncio.create_task(self.get_cluster_balance(cluster))) + futs.append(self.loop.create_task(self.get_cluster_balance(cluster))) await asyncio.gather(*futs) async def fetch_burn_rates(self) -> None: futs = [] for cluster in self.clusters: - futs.append(asyncio.create_task(self.get_cluster_burn_rate(cluster))) + futs.append(self.loop.create_task(self.get_cluster_burn_rate(cluster))) await asyncio.gather(*futs) async def fetch_all(self) -> None: @@ -304,13 +322,39 @@ class SSVClusterExporter(BaseSettings): ethereum_rpc: Web3RpcClient base_ssv_url: furl.furl = furl.furl("https://api.ssv.network/api/v4/") - session: client.ClientSession + loop: asyncio.AbstractEventLoop + + # Stopping + stopping: bool = False + stopped: asyncio.Event = asyncio.Event() @computed_field # type: ignore - @property + @functools.cached_property def network_views(self) -> AsyncContract: return get_ssv_network_views_contract(self.ethereum_rpc, self.network) # type: ignore[arg-type] + @computed_field # type: ignore + @functools.cached_property + def session(self) -> client.ClientSession: + return client.ClientSession(loop=self.loop) + + def on_runner_task_done(self, *args: typing.Any) -> None: + self.stopped.set() + + def start(self) -> None: + self._runner_task = self.loop.create_task(self.run()) + # Raise event when task is stopped + self._runner_task.add_done_callback(self.on_runner_task_done) + + async def stop(self) -> None: + logger.info("Gracefully shutting down application") + self.stopping = True + self._runner_task.cancel() + if not self.stopped.is_set(): + await self.stopped.wait() + await self.session.close() + logger.info("Stopped components, will exit") + async def sleep(self) -> None: await asyncio.sleep(self.interval_ms / 1000) @@ -390,11 +434,11 @@ async def fetch_clusters_info(self) -> list[SSVCluster]: for owner_config in self.owners: futs.append( - asyncio.create_task(self.get_owner_clusters(owner_config.address)) + self.loop.create_task(self.get_owner_clusters(owner_config.address)) ) for cluster_config in self.clusters: futs.append( - asyncio.create_task(self.get_cluster_by_id(cluster_config.cluster_id)) + self.loop.create_task(self.get_cluster_by_id(cluster_config.cluster_id)) ) responses = await asyncio.gather(*futs) @@ -434,7 +478,9 @@ async def clusters_updates(self) -> None: """Run cluster-specific metrics update.""" clusters = set(await self.fetch_clusters_info()) latest_metric_fetcher = SSVClusterContract( - network_views=self.network_views, clusters=clusters + network_views=self.network_views, + clusters=clusters, + loop=self.loop, ) await latest_metric_fetcher.fetch_all() self.update_clusters_metrics(*clusters) @@ -454,15 +500,23 @@ async def tick(self) -> None: ) except Exception: logger.exception("Failed to update cluster details") + if self.stopping: + await self.session.close() - async def loop(self) -> None: + async def run(self) -> None: """Infinite loop that spawns checker tasks.""" - while True: - asyncio.ensure_future(self.tick()) + while not self.stopping: + self.loop.create_task(self.tick()) await self.sleep() self.stopped.set() +# Aiohttp app key for exporter component +exporter_app_key: web.AppKey[SSVClusterExporter] = web.AppKey( + "exporter", SSVClusterExporter +) + + # ############# # Entry point def main() -> None: @@ -476,7 +530,7 @@ def main() -> None: asyncio.set_event_loop(loop) try: config_data = yaml.safe_load(config_text) - config_data["session"] = client.ClientSession(loop=loop) + config_data["loop"] = loop exporter = SSVClusterExporter(**config_data) except yaml.error.YAMLError: logger.exception("Invalid config YAML") @@ -485,9 +539,10 @@ def main() -> None: logger.exception("Invalid config data") exit(2) else: - app = get_application() - loop.create_task(exporter.loop()) - web.run_app(app, host=args.host, port=args.port, loop=loop) + app = get_application(exporter) + web.run_app( + app, host=args.host, port=args.port, loop=loop, handler_cancellation=True + ) if __name__ == "__main__": diff --git a/tests.py b/tests.py index 343e41e..b05448a 100644 --- a/tests.py +++ b/tests.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import AsyncGenerator import socket import typing @@ -18,17 +19,16 @@ def find_free_port() -> int: @pytest_asyncio.fixture async def metrics_server(exporter_data: typing.Any) -> AsyncGenerator[str, None]: - exporter_data["session"] = client.ClientSession() + exporter_data["loop"] = asyncio.get_event_loop() exporter = ssv_cluster_exporter.SSVClusterExporter(**exporter_data) port = find_free_port() - # Acquire data once - await exporter.tick() - app = ssv_cluster_exporter.get_application() + app = ssv_cluster_exporter.get_application(exporter) runner = web.AppRunner(app) await runner.setup() site = web.TCPSite(runner, "localhost", port) await site.start() yield f"http://localhost:{port}" + await runner.shutdown() await site.stop() @@ -59,32 +59,33 @@ async def metrics_server(exporter_data: typing.Any) -> AsyncGenerator[str, None] ], ) async def test_metrics(metrics_server: str) -> None: - session = client.ClientSession() - response = await session.get(f"{metrics_server}/metrics") - assert response.status == 200 matched_metrics = set() - for metric in text_string_to_metric_families(await response.text()): - if metric.name.startswith("ssv_cluster"): - sample = metric.samples[0] - assert ( - sample.labels["cluster_id"] - == "0xde12c5ce1bc895c3ed8b81afcbbb55b3efff7ae9ebac5dbd2ebac3bd29474c09" # noqa: W503 - ) - assert sample.labels["id"] == "1278541" - assert sample.labels["network"] == "holesky" - assert sample.labels["operators"] == "1092,1093,1094,1095" - assert ( - sample.labels["owner"] == "0xD4BB555d3B0D7fF17c606161B44E372689C14F4B" - ) - matched_metrics.add(metric.name) - elif metric.name in ( - "ssv_network_fee", - "ssv_minimum_liquidation_collateral", - "ssv_liquidation_threshold_period", - ): - sample = metric.samples[0] - assert sample.labels["network"] == "holesky" - matched_metrics.add(metric.name) + async with client.ClientSession() as session: + response = await session.get(f"{metrics_server}/metrics") + assert response.status == 200 + for metric in text_string_to_metric_families(await response.text()): + if metric.name.startswith("ssv_cluster"): + sample = metric.samples[0] + assert ( + sample.labels["cluster_id"] + == "0xde12c5ce1bc895c3ed8b81afcbbb55b3efff7ae9ebac5dbd2ebac3bd29474c09" # noqa: W503 + ) + assert sample.labels["id"] == "1278541" + assert sample.labels["network"] == "holesky" + assert sample.labels["operators"] == "1092,1093,1094,1095" + assert ( + sample.labels["owner"] + == "0xD4BB555d3B0D7fF17c606161B44E372689C14F4B" + ) + matched_metrics.add(metric.name) + elif metric.name in ( + "ssv_network_fee", + "ssv_minimum_liquidation_collateral", + "ssv_liquidation_threshold_period", + ): + sample = metric.samples[0] + assert sample.labels["network"] == "holesky" + matched_metrics.add(metric.name) assert matched_metrics == { "ssv_cluster_validators_count", "ssv_cluster_balance",