Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/guides/Basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ async with ctx.scope("basics", basic_state):
do_something_contextually()

print("Final:")
# when leaving the updated scope we go back to previously defined state
# when leaving the updating scope we go back to previously defined state
do_something_contextually()

# do_something_contextually() # calling it outside of any context scope will cause an error
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "uv_build"
[project]
name = "draive"
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
version = "0.97.0"
version = "0.98.0"
readme = "README.md"
maintainers = [
{ name = "Kacper Kaliński", email = "kacper.kalinski@miquido.com" },
Expand All @@ -24,7 +24,7 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Application Frameworks",
]
license = { file = "LICENSE" }
dependencies = ["numpy~=2.3", "haiway~=0.42"]
dependencies = ["numpy~=2.3", "haiway~=0.42.0"]

[project.urls]
Homepage = "https://miquido.com"
Expand Down
1 change: 1 addition & 0 deletions src/draive/aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ async def __aenter__(self) -> Iterable[State]:
await self._prepare_s3_client()
features.append(
ResourcesRepository(
list_fetching=self.list,
fetching=self.fetch,
uploading=self.upload,
),
Expand Down
216 changes: 215 additions & 1 deletion src/draive/aws/s3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import mimetypes
import re
from asyncio import gather
from collections.abc import Collection, Mapping
from collections.abc import Collection, Mapping, Sequence
from io import BytesIO
from pathlib import Path
from typing import Any
Expand All @@ -17,6 +18,63 @@


class AWSS3Mixin(AWSAPI):
async def list(
self,
*,
uri: str | None = None,
recursive: bool = True,
limit: int | None = None,
**extra: Any,
) -> Sequence[ResourceReference]:
"""List buckets or objects stored under an S3 bucket or prefix.

Parameters
----------
uri
Optional S3 URI pointing to the bucket or prefix
(``s3://bucket/prefix``). When omitted, lists buckets.
recursive
When ``True`` (default), list all objects under the prefix.
When ``False``, list only the immediate children and prefix markers.
limit
Optional maximum number of results to return.

Returns
-------
Sequence[ResourceReference]
References to objects and prefix markers under the requested scope.

Raises
------
ValueError
If the ``uri`` does not use the ``s3://`` scheme.
AWSAccessDenied
If S3 rejects the request due to missing or invalid credentials.
AWSResourceNotFound
If the bucket does not exist.
AWSError
For other S3 client failures.
"""
if uri is None:
return await self._list_buckets(
limit=limit,
**extra,
)

if not uri.startswith("s3://"):
raise ValueError("Unsupported list uri scheme")

parsed_uri: ParseResult = urlparse(uri) # s3://bucket/prefix
bucket: str = parsed_uri.netloc
prefix: str = parsed_uri.path.lstrip("/")
return await self._list(
bucket=bucket,
prefix=prefix,
recursive=recursive,
limit=limit,
**extra,
)

async def fetch(
self,
uri: str,
Expand Down Expand Up @@ -223,6 +281,113 @@ async def upload(

return META_EMPTY

    @asynchronous
    def _list(  # noqa: C901, PLR0912
        self,
        *,
        bucket: str,
        prefix: str,
        recursive: bool,
        limit: int | None,
        **extra: Any,
    ) -> Sequence[ResourceReference]:
        """Collect object and prefix references under ``bucket``/``prefix``.

        Runs on a worker thread (``@asynchronous``) because the boto3-style
        client is blocking. Paginates ``list_objects_v2`` and stops as soon
        as ``limit`` references have been gathered.

        Raises the translated AWS error (see ``_translate_client_error``)
        when the underlying client call fails.
        """
        try:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            # Cap both the total item count and the per-page size when a
            # limit is requested; S3 serves at most 1000 keys per page.
            pagination_config: dict[str, int] = {}
            if limit is not None:
                pagination_config["MaxItems"] = limit
                pagination_config["PageSize"] = min(1000, limit)

            request: dict[str, Any] = {
                "Bucket": bucket,
                "Prefix": prefix,
            }
            # Without a delimiter S3 lists every descendant key; adding "/"
            # restricts results to immediate children plus CommonPrefixes.
            if not recursive:
                request["Delimiter"] = "/"

            if pagination_config:
                request["PaginationConfig"] = pagination_config

            # Caller-supplied extras are forwarded verbatim to paginate().
            if extra:
                request.update(extra)

            references: list[ResourceReference] = []
            for page in paginator.paginate(**request):
                for entry in page.get("Contents", []):
                    if limit is not None and len(references) >= limit:
                        break

                    key = entry.get("Key", "")
                    # Skip empty keys and the "directory marker" object that
                    # mirrors the listed prefix itself (e.g. "photos/").
                    if not key or (key == prefix and key.endswith("/")):
                        continue

                    references.append(
                        _object_reference(
                            bucket=bucket,
                            key=key,
                            entry=entry,
                        )
                    )

                if limit is not None and len(references) >= limit:
                    break

                # Non-recursive listings surface sub-prefixes separately;
                # represent each as a "prefix" typed reference.
                for common_prefix in page.get("CommonPrefixes", []):
                    if limit is not None and len(references) >= limit:
                        break

                    prefix_key = common_prefix.get("Prefix", "")
                    if not prefix_key:
                        continue

                    references.append(
                        ResourceReference.of(
                            f"s3://{bucket}/{prefix_key}",
                            name=_object_name(prefix_key),
                            meta=Meta.of({"type": "prefix"}),
                        )
                    )

                if limit is not None and len(references) >= limit:
                    break

            return references

        except ClientError as exc:
            # Map the generic ClientError onto the project's AWS error types.
            raise _translate_client_error(
                error=exc,
                bucket=bucket,
                key=prefix,
            ) from exc

@asynchronous
def _list_buckets(
self,
*,
limit: int | None,
**extra: Any,
) -> Sequence[ResourceReference]:
try:
response: Mapping[str, Any] = self._s3_client.list_buckets(**extra)
buckets: Collection[Mapping[str, Any]] = response.get("Buckets", ())
references: list[ResourceReference] = [
_bucket_reference(bucket) for bucket in buckets if bucket.get("Name")
]
if limit is not None:
return references[:limit]

return references

except ClientError as exc:
error_info: Mapping[str, Any] = getattr(exc, "response", {}).get("Error", {})
code: str = str(error_info.get("Code") or "").strip()
message: str = str(error_info.get("Message") or str(exc)).strip()
raise AWSError(
uri="s3://",
code=code or None,
message=message,
) from exc

@asynchronous
def _upload(
self,
Expand Down Expand Up @@ -292,6 +457,55 @@ def _sanitize_metadata(
return sanitized


def _object_name(key: str) -> str | None:
trimmed = key.rstrip("/")
if not trimmed:
return None
return Path(trimmed).name


def _object_reference(
    *,
    bucket: str,
    key: str,
    entry: Mapping[str, Any],
) -> ResourceReference:
    """Convert a ``list_objects_v2`` Contents entry into a ResourceReference.

    Copies the entry's ETag (surrounding quotes stripped), size, last
    modification time and storage class into metadata, and guesses a MIME
    type from the key's extension.
    """
    details: dict[str, BasicValue] = {"type": "object"}

    etag = entry.get("ETag")
    if etag is not None:
        details["etag"] = str(etag).strip('"')

    size = entry.get("Size")
    if size is not None:
        details["size"] = int(size)

    last_modified = entry.get("LastModified")
    if last_modified is not None:
        details["last_modified"] = str(last_modified)

    storage_class = entry.get("StorageClass")
    if storage_class is not None:
        details["storage_class"] = str(storage_class)

    guessed_mime, _ = mimetypes.guess_type(key)
    if guessed_mime is not None:
        details["mime_type"] = guessed_mime

    return ResourceReference.of(
        f"s3://{bucket}/{key}",
        name=_object_name(key),
        meta=Meta.of(details),
    )


def _bucket_reference(
    bucket: Mapping[str, Any],
) -> ResourceReference:
    """Turn a ``list_buckets`` descriptor into a ResourceReference.

    NOTE(review): callers filter out unnamed buckets, but an empty ``Name``
    here still yields the degenerate ``s3://`` uri with ``name=None``.
    """
    bucket_name: str = str(bucket.get("Name") or "").strip()
    details: dict[str, BasicValue] = {"type": "bucket"}
    created = bucket.get("CreationDate")
    if created:
        details["creation_date"] = str(created)

    return ResourceReference.of(
        f"s3://{bucket_name}",
        name=bucket_name or None,
        meta=Meta.of(details),
    )


def _translate_client_error(
*,
error: ClientError,
Expand Down
Loading