Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/draive/aws/cloudwatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _put_log(
)

except ClientError as exc:
raise _translate_cloudwatch_error(
raise translate_cloudwatch_error(
error=exc,
service="logs",
operation="put_log",
Expand Down Expand Up @@ -110,7 +110,7 @@ def _put_event(
)

except ClientError as exc:
raise _translate_cloudwatch_error(
raise translate_cloudwatch_error(
error=exc,
service="events",
operation="put_events",
Expand Down Expand Up @@ -145,7 +145,7 @@ def _put_metric(
attributes: Mapping[str, ObservabilityAttribute],
) -> None:
try:
dimensions = _format_metric_dimensions(attributes)
dimensions = format_metric_dimensions(attributes)
metric_data: dict[str, Any] = {
"MetricName": metric,
"Value": value,
Expand All @@ -163,15 +163,15 @@ def _put_metric(
)

except ClientError as exc:
raise _translate_cloudwatch_error(
raise translate_cloudwatch_error(
error=exc,
service="cloudwatch",
operation="put_metric_data",
resource=namespace,
) from exc


def _format_metric_dimensions(
def format_metric_dimensions(
attributes: Mapping[str, ObservabilityAttribute],
) -> list[dict[str, str]]:
dimensions: list[dict[str, str]] = []
Expand Down Expand Up @@ -246,7 +246,7 @@ def _truncate_dimension(
return value[:max_length]


def _translate_cloudwatch_error(
def translate_cloudwatch_error(
*,
error: ClientError,
service: str,
Expand Down
157 changes: 149 additions & 8 deletions src/draive/aws/observability.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import asyncio
import json
import traceback
from collections.abc import Mapping, MutableMapping, MutableSequence, Sequence
from datetime import UTC, datetime
from time import monotonic, time_ns
from typing import Any, Self, cast
from uuid import UUID, uuid4

from boto3 import Session # pyright: ignore[reportMissingTypeStubs]
from botocore.exceptions import ClientError # pyright: ignore[reportMissingModuleSource]
from haiway import (
MISSING,
ContextIdentifier,
Expand All @@ -15,13 +19,22 @@
)
from haiway.context import ObservabilityMetricKind

from draive.aws.state import AWSCloudwatch
from draive.aws.cloudwatch import format_metric_dimensions, translate_cloudwatch_error
from draive.aws.types import (
AWSCloudwatchEventPutting,
AWSCloudwatchLogPutting,
AWSCloudwatchMetricPutting,
)

__all__ = ("CloudwatchObservability",)


def CloudwatchObservability( # noqa: C901, PLR0915
*,
region_name: str | None = None,
access_key_id: str | None = None,
secret_access_key: str | None = None,
profile_name: str | None = None,
log_level: ObservabilityLevel,
log_group: str,
log_stream: str,
Expand All @@ -33,14 +46,129 @@ def CloudwatchObservability( # noqa: C901, PLR0915
root_scope: ContextIdentifier | None = None
scopes: MutableMapping[UUID, ScopeStore] = {}

session_kwargs: MutableMapping[str, object] = {}
if access_key_id:
session_kwargs["aws_access_key_id"] = access_key_id

if secret_access_key:
session_kwargs["aws_secret_access_key"] = secret_access_key

if region_name:
session_kwargs["region_name"] = region_name

if profile_name:
session_kwargs["profile_name"] = profile_name

session: Session = Session(**session_kwargs)
cloudwatch_client: Any = session.client( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
"cloudwatch",
)
cloudwatch_logs_client: Any = session.client( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
"logs",
)
eventbridge_client: Any = session.client( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
"events",
)

async def log_putting(
*,
log_group: str,
log_stream: str,
message: str,
) -> None:
try:
await asyncio.to_thread(
cloudwatch_logs_client.put_log_events, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
logGroupName=log_group,
logStreamName=log_stream,
logEvents=[
{
"timestamp": time_ns() // 1_000_000,
"message": message,
}
],
)

except ClientError as exc:
raise translate_cloudwatch_error(
error=exc,
service="logs",
operation="put_log",
resource=f"{log_group}/{log_stream}",
) from exc

async def metric_putting(
*,
namespace: str,
metric: str,
value: float | int,
unit: str | None,
attributes: Mapping[str, ObservabilityAttribute],
) -> None:
try:
dimensions = format_metric_dimensions(attributes)
metric_data: dict[str, object] = {
"MetricName": metric,
"Value": value,
}

if dimensions:
metric_data["Dimensions"] = dimensions

if unit:
metric_data["Unit"] = unit

await asyncio.to_thread(
cloudwatch_client.put_metric_data, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
Namespace=namespace,
MetricData=[metric_data],
)

except ClientError as exc:
raise translate_cloudwatch_error(
error=exc,
service="cloudwatch",
operation="put_metric_data",
resource=namespace,
) from exc

async def event_putting(
*,
event_bus: str,
event_source: str,
detail_type: str,
detail: str,
) -> None:
try:
await asyncio.to_thread(
eventbridge_client.put_events, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
Entries=[
{
"EventBusName": event_bus,
"Source": event_source,
"DetailType": detail_type,
"Detail": detail,
"Time": datetime.now(tz=UTC),
}
],
)

except ClientError as exc:
raise translate_cloudwatch_error(
error=exc,
service="events",
operation="put_events",
resource=event_bus,
) from exc

def trace_identifying(
scope: ContextIdentifier,
/,
) -> UUID:
assert root_scope is not None # nosec: B101
assert scope.scope_id in scopes # nosec: B101

return root_scope.scope_id
return trace_id

def log_recording(
scope: ContextIdentifier,
Expand Down Expand Up @@ -153,6 +281,9 @@ def scope_entering(
event_bus=event_bus,
event_source=event_source,
metrics_namespace=metrics_namespace,
log_putting=log_putting,
metric_putting=metric_putting,
event_putting=event_putting,
)
root_scope = scope

Expand Down Expand Up @@ -216,10 +347,13 @@ class ScopeStore:
"_end_time",
"_start_time",
"event_bus",
"event_putting",
"event_source",
"identifier",
"log_group",
"log_putting",
"log_stream",
"metric_putting",
"metrics_namespace",
"nested",
"trace_id",
Expand All @@ -236,6 +370,9 @@ def __init__(
event_bus: str,
event_source: str,
metrics_namespace: str,
log_putting: AWSCloudwatchLogPutting,
metric_putting: AWSCloudwatchMetricPutting,
event_putting: AWSCloudwatchEventPutting,
) -> None:
self.identifier: ContextIdentifier = identifier
self.trace_id: UUID = trace_id
Expand All @@ -244,6 +381,9 @@ def __init__(
self.event_bus: str = event_bus
self.event_source: str = event_source
self.metrics_namespace: str = metrics_namespace
self.log_putting: AWSCloudwatchLogPutting = log_putting
self.metric_putting: AWSCloudwatchMetricPutting = metric_putting
self.event_putting: AWSCloudwatchEventPutting = event_putting
self.nested: MutableSequence[ScopeStore] = []
self._start_time: float = monotonic()
self._end_time: float | None = None
Expand All @@ -262,6 +402,9 @@ def child(
event_bus=self.event_bus,
event_source=self.event_source,
metrics_namespace=self.metrics_namespace,
log_putting=self.log_putting,
metric_putting=self.metric_putting,
event_putting=self.event_putting,
)
self.nested.append(child)
return child
Expand Down Expand Up @@ -312,7 +455,7 @@ def record_log(
attributes["exception.message"] = str(exception)

ctx.spawn_background(
AWSCloudwatch.put_log,
self.log_putting,
log_stream=self.log_stream,
log_group=self.log_group,
message=_json_dumps(
Expand All @@ -327,16 +470,14 @@ def record_log(
}
),
)
if exception is not None:
self.record_exception(exception)

def record_exception(
self,
exception: BaseException,
/,
) -> None:
ctx.spawn_background(
AWSCloudwatch.put_event,
self.event_putting,
event_bus=self.event_bus,
event_source=self.event_source,
detail_type="exception",
Expand Down Expand Up @@ -374,7 +515,7 @@ def record_event(
attributes: Mapping[str, ObservabilityAttribute],
) -> None:
ctx.spawn_background(
AWSCloudwatch.put_event,
self.event_putting,
event_bus=self.event_bus,
event_source=self.event_source,
detail_type="event",
Expand Down Expand Up @@ -412,7 +553,7 @@ def record_metric(
)
metric_attributes.setdefault("otel.metric.kind", kind)
ctx.spawn_background(
AWSCloudwatch.put_metric,
self.metric_putting,
namespace=self.metrics_namespace,
metric=name,
value=value,
Expand Down
7 changes: 2 additions & 5 deletions src/draive/qdrant/store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from asyncio import get_running_loop
import asyncio
from collections.abc import Callable, Iterable, Sequence
from contextvars import copy_context
from typing import Any, Literal, cast, overload
from uuid import uuid4

Expand Down Expand Up @@ -201,9 +200,7 @@ async def store[Model: DataModel](
parallel_tasks: int = 1,
**extra: Any,
) -> None:
await get_running_loop().run_in_executor(
None,
copy_context().run,
await asyncio.to_thread(
self._partial_store(
model,
objects=objects,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_aws_cloudwatch_observability.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_format_metric_dimensions_respects_sequence_values(
"skip": None,
}

dimensions = cloudwatch._format_metric_dimensions(attributes)
dimensions = cloudwatch.format_metric_dimensions(attributes)

assert dimensions == [
{"Name": "service", "Value": "draive"},
Expand All @@ -86,7 +86,7 @@ def test_translate_cloudwatch_error_includes_code(
}
error = types.SimpleNamespace(response=response)

exc = cloudwatch._translate_cloudwatch_error(
exc = cloudwatch.translate_cloudwatch_error(
error=error,
service="events",
operation="put_events",
Expand All @@ -107,7 +107,7 @@ def test_translate_cloudwatch_error_access_denied(
}
error = types.SimpleNamespace(response=response)

exc = cloudwatch._translate_cloudwatch_error(
exc = cloudwatch.translate_cloudwatch_error(
error=error,
service="events",
operation="put_events",
Expand Down
Loading