Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ TESTS_PATH := tests
-include .env

ifndef UV_VERSION
UV_VERSION := 0.9.14
UV_VERSION := 0.9.24
endif

.PHONY: uv_check venv sync update format lint test docs docs-server docs-format docs-lint release
Expand Down
5 changes: 5 additions & 0 deletions src/draive/aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from draive.aws.client import AWS
from draive.aws.observability import CloudwatchObservability
from draive.aws.state import AWSSQS, AWSCloudwatch
from draive.aws.types import AWSAccessDenied, AWSError, AWSResourceNotFound

__all__ = (
"AWS",
"AWSSQS",
"AWSAccessDenied",
"AWSCloudwatch",
"AWSError",
"AWSResourceNotFound",
"CloudwatchObservability",
)
22 changes: 20 additions & 2 deletions src/draive/aws/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
class AWSAPI:
"""Low-level AWS session and client management.

Provides an asynchronous S3 and SQS client initializers that other mixins
can rely on without duplicating boto3 session wiring.
Provides asynchronous client initializers for AWS services so higher-level
mixins can share a single boto3 session without duplicating configuration.
"""

__slots__ = (
"_cloudwatch_client",
"_cloudwatch_logs_client",
"_eventbridge_client",
"_s3_client",
"_session",
"_sqs_client",
Expand Down Expand Up @@ -52,6 +55,9 @@ def __init__(
kwargs["region_name"] = region

self._session: Session = Session(**kwargs)
self._cloudwatch_client: Any
self._cloudwatch_logs_client: Any
self._eventbridge_client: Any
self._s3_client: Any
self._sqs_client: Any

Expand All @@ -67,6 +73,18 @@ def _prepare_sqs_client(self) -> None:
service_name="sqs",
)

@asynchronous
def _prepare_cloudwatch_clients(self) -> None:
self._cloudwatch_client = self._session.client( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
service_name="cloudwatch",
)
self._cloudwatch_logs_client = self._session.client( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
service_name="logs",
)
self._eventbridge_client = self._session.client( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
service_name="events",
)

@property
def region(self) -> str | None:
"""Currently configured AWS region for the session."""
Expand Down
21 changes: 17 additions & 4 deletions src/draive/aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from haiway import State

from draive.aws.api import AWSAPI
from draive.aws.cloudwatch import AWSCloudwatchMixin
from draive.aws.s3 import AWSS3Mixin
from draive.aws.sqs import AWSSQSMixin
from draive.aws.state import AWSSQS
from draive.aws.state import AWSSQS, AWSCloudwatch
from draive.resources import ResourcesRepository

__all__ = ("AWS",)
Expand All @@ -17,9 +18,10 @@
class AWS(
AWSS3Mixin,
AWSSQSMixin,
AWSCloudwatchMixin,
AWSAPI,
):
"""AWS service facade bundling S3 and SQS integrations.
"""AWS service facade bundling S3, SQS, and CloudWatch integrations.

Parameters
----------
Expand All @@ -45,15 +47,15 @@ def __init__(
region_name: str | None = None,
access_key_id: str | None = None,
secret_access_key: str | None = None,
features: Collection[type[ResourcesRepository | AWSSQS]] | None = None,
features: Collection[type[ResourcesRepository | AWSSQS | AWSCloudwatch]] | None = None,
) -> None:
super().__init__(
region_name=region_name,
access_key_id=access_key_id,
secret_access_key=secret_access_key,
)

self._features: Collection[type[ResourcesRepository | AWSSQS]]
self._features: Collection[type[ResourcesRepository | AWSSQS | AWSCloudwatch]]
if features is not None:
self._features = features

Expand All @@ -79,6 +81,17 @@ async def __aenter__(self) -> Iterable[State]:
AWSSQS(queue_accessing=self._queue_access),
)

if AWSCloudwatch in self._features:
await self._prepare_cloudwatch_clients()

features.append(
AWSCloudwatch(
log_putting=self.put_log,
metric_putting=self.put_metric,
event_putting=self.put_event,
),
)

return features

async def __aexit__(
Expand Down
Loading