Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "uv_build"
[project]
name = "draive"
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
version = "0.94.1"
version = "0.95.0"
readme = "README.md"
maintainers = [
{ name = "Kacper Kaliński", email = "kacper.kalinski@miquido.com" },
Expand All @@ -24,7 +24,7 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Application Frameworks",
]
license = { file = "LICENSE" }
dependencies = ["numpy~=2.3", "haiway~=0.39.2"]
dependencies = ["numpy~=2.3", "haiway~=0.40.0"]

[project.urls]
Homepage = "https://miquido.com"
Expand Down Expand Up @@ -52,6 +52,7 @@ opentelemetry = [
]
httpx = ["haiway[httpx]", "httpx"]
postgres = ["haiway[postgres]", "asyncpg"]
rabbitmq = ["haiway[rabbitmq]", "pika"]
qdrant = ["qdrant-client~=1.15.0"]
dev = [
"bandit~=1.8",
Expand Down
12 changes: 10 additions & 2 deletions src/draive/aws/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
class AWSAPI:
"""Low-level AWS session and client management.

Provides an asynchronous S3 client initializer that other mixins
Provides asynchronous S3 and SQS client initializers that other mixins
can rely on without duplicating boto3 session wiring.
"""

__slots__ = (
"_s3_client",
"_session",
"_sqs_client",
)

def __init__(
Expand Down Expand Up @@ -52,13 +53,20 @@ def __init__(

self._session: Session = Session(**kwargs)
self._s3_client: Any
self._sqs_client: Any

@asynchronous
def _prepare_client(self) -> None:
def _prepare_s3_client(self) -> None:
self._s3_client = self._session.client( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
service_name="s3",
)

@asynchronous
def _prepare_sqs_client(self) -> None:
    """Create the SQS client from the stored session and keep it on the instance.

    The result is stored in ``self._sqs_client``; nothing is returned.
    NOTE(review): ``asynchronous`` is declared outside this view; callers
    await this method, so it presumably moves the blocking client creation
    off the event loop -- confirm against the decorator.
    """
    # The session client factory is untyped for pyright, hence the suppression.
    self._sqs_client = self._session.client(  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        service_name="sqs",
    )

@property
def region(self) -> str | None:
"""Currently configured AWS region for the session."""
Expand Down
31 changes: 21 additions & 10 deletions src/draive/aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from draive.aws.api import AWSAPI
from draive.aws.s3 import AWSS3Mixin
from draive.aws.sqs import AWSSQSMixin
from draive.aws.state import AWSSQS
from draive.resources import ResourcesRepository

__all__ = ("AWS",)
Expand All @@ -14,9 +16,10 @@
@final
class AWS(
AWSS3Mixin,
AWSSQSMixin,
AWSAPI,
):
"""AWS service facade bundling S3 and repository integrations.
"""AWS service facade bundling S3 and SQS integrations.

Parameters
----------
Expand All @@ -30,7 +33,8 @@ class AWS(
Secret key paired with ``access_key_id`` when overriding
credentials.
features
Collection of repository feature classes to activate while the
Collection of feature state classes (for example
:class:`ResourcesRepository`, :class:`AWSSQS`) to activate while the
client is bound in a context manager.
"""

Expand All @@ -41,39 +45,46 @@ def __init__(
region_name: str | None = None,
access_key_id: str | None = None,
secret_access_key: str | None = None,
features: Collection[type[ResourcesRepository]] | None = None,
features: Collection[type[ResourcesRepository | AWSSQS]] | None = None,
) -> None:
super().__init__(
region_name=region_name,
access_key_id=access_key_id,
secret_access_key=secret_access_key,
)

self._features: Collection[type[ResourcesRepository]]
self._features: Collection[type[ResourcesRepository | AWSSQS]]
if features is not None:
self._features = features

else:
self._features = (ResourcesRepository,)
self._features = ()

async def __aenter__(self) -> Iterable[State]:
"""Prepare the AWS client and bind selected features to context."""
await self._prepare_client()
features: list[State] = []

if ResourcesRepository in self._features:
return (
await self._prepare_s3_client()
features.append(
ResourcesRepository(
fetching=self.fetch,
uploading=self.upload,
),
)

return ()
if AWSSQS in self._features:
await self._prepare_sqs_client()

features.append(
AWSSQS(queue_accessing=self._queue_access),
)

return features

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""No-op cleanup to satisfy the async context manager protocol."""
pass
94 changes: 43 additions & 51 deletions src/draive/aws/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@


class AWSS3Mixin(AWSAPI):
"""S3 helper mixin that implements fetch, download, and upload APIs."""

async def fetch(
self,
uri: str,
Expand Down Expand Up @@ -177,9 +175,6 @@ def _get_object_info(
else:
return response.get("ContentType"), Meta.of(response.get("Metadata"))

# Satisfy the type checker; control flow always leaves in try/except/else.
return None, META_EMPTY

async def upload(
self,
uri: str,
Expand Down Expand Up @@ -254,6 +249,49 @@ def _upload(
) from exc


def _sanitize_metadata_value(value: Any) -> str:
    """Render ``value`` as single-line text of at most 1,024 UTF-8 bytes.

    Newlines, carriage returns and tabs become spaces, remaining control
    characters are dropped, whitespace runs collapse to one space and the
    ends are trimmed. Text longer than 1,021 bytes is cut at that byte
    offset and suffixed with ``"..."``.
    """
    flattened = re.sub(r"[\n\r\t]+", " ", str(value))
    printable = re.sub(r"[\x00-\x1F\x7F-\x9F]", "", flattened)
    cleaned = re.sub(r"\s+", " ", printable).strip()

    # 1,024-byte budget per metadata value, minus 3 bytes for the "..." marker.
    byte_budget = 1021
    encoded = cleaned.encode("utf-8")
    if len(encoded) <= byte_budget:
        return cleaned

    # A byte-level cut can split a multi-byte character; "ignore" drops the fragment.
    return encoded[:byte_budget].decode("utf-8", "ignore") + "..."


def _sanitize_metadata(
    meta: Mapping[str, BasicValue] | None,
) -> dict[str, str]:
    """Build a ``str -> str`` metadata dict from ``meta``.

    Keys and values are both passed through ``_sanitize_metadata_value``.
    An entry is kept only when its sanitized key and sanitized value are
    both non-empty. A missing or empty ``meta`` yields an empty dict.
    """
    cleaned: dict[str, str] = {}
    for raw_key, raw_value in (meta or {}).items():
        key = _sanitize_metadata_value(raw_key)
        value = _sanitize_metadata_value(raw_value)
        if not key or not value:
            # Nothing usable is left after sanitizing, so drop the entry.
            continue

        cleaned[key] = value

    return cleaned


def _translate_client_error(
*,
error: ClientError,
Expand Down Expand Up @@ -305,49 +343,3 @@ def _translate_client_error(
code=code or None,
message=message,
)


def _sanitize_metadata_value(value: Any) -> str:
    """Render ``value`` as single-line ASCII text of at most 1,024 bytes.

    Non-ASCII characters are dropped first. Then newlines, carriage returns
    and tabs become spaces, remaining control characters are removed,
    whitespace runs collapse to one space and the ends are trimmed. Text
    longer than 1,021 bytes is cut there and suffixed with ``"..."``.
    """
    text = str(value).encode("ascii", "ignore").decode("ascii")

    # Applied in order: flatten line breaks/tabs, strip control characters,
    # then collapse whitespace runs.
    for pattern, replacement in (
        (r"[\n\r\t]+", " "),
        (r"[\x00-\x1F\x7F-\x9F]", ""),
        (r"\s+", " "),
    ):
        text = re.sub(pattern, replacement, text)
    text = text.strip()

    # 1,024-byte budget per metadata value; the marker's bytes come out of it.
    suffix = "..."
    keep = 1024 - len(suffix)
    raw = text.encode("utf-8")
    if len(raw) > keep:
        # "ignore" discards any partial character left by the byte-level cut.
        text = raw[:keep].decode("utf-8", "ignore") + suffix

    return text


def _sanitize_metadata(
    meta: Mapping[str, BasicValue] | None,
) -> dict[str, str]:
    """Build a ``str -> str`` metadata dict from ``meta``.

    Keys and values are both passed through ``_sanitize_metadata_value``.
    An entry is kept only when its sanitized key and sanitized value are
    both non-empty. A missing or empty ``meta`` yields an empty dict.
    """
    if not meta:
        return {}

    # Sanitize the key and the value of every entry before filtering.
    pairs = (
        (_sanitize_metadata_value(key), _sanitize_metadata_value(value))
        for key, value in meta.items()
    )
    # Entries that sanitize down to an empty key or value are skipped.
    return {key: value for key, value in pairs if key and value}
Loading