Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from datetime import datetime
from typing import Any, overload

from elastic_transport import ObjectApiResponse # noqa: TC002
from key_value.shared.errors import DeserializationError
from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json
from elastic_transport import ObjectApiResponse
from elastic_transport import SerializationError as ElasticsearchSerializationError
from key_value.shared.errors import DeserializationError, SerializationError
from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict
from key_value.shared.utils.sanitize import (
ALPHANUMERIC_CHARACTERS,
LOWERCASE_ALPHABET,
Expand All @@ -22,7 +23,7 @@
BaseEnumerateKeysStore,
BaseStore,
)
from key_value.aio.stores.elasticsearch.utils import new_bulk_action
from key_value.aio.stores.elasticsearch.utils import LessCapableJsonSerializer, LessCapableNdjsonSerializer, new_bulk_action

try:
from elasticsearch import AsyncElasticsearch
Expand Down Expand Up @@ -55,10 +56,17 @@
"type": "keyword",
},
"value": {
"type": "keyword",
"index": False,
"doc_values": False,
"ignore_above": 256,
"properties": {
# You might think the `string` field should be a text/keyword field
# but this is the recommended mapping for large stringified json
"string": {
"type": "object",
"enabled": False,
},
"flattened": {
"type": "flattened",
},
},
},
},
}
Expand All @@ -73,12 +81,14 @@
ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "."


def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:
document: dict[str, Any] = {
"collection": collection,
"key": key,
"value": managed_entry.to_json(include_metadata=False),
}
def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry, *, native_storage: bool = False) -> dict[str, Any]:
document: dict[str, Any] = {"collection": collection, "key": key, "value": {}}

# Store in appropriate field based on mode
if native_storage:
document["value"]["flattened"] = managed_entry.value_as_dict
else:
document["value"]["string"] = managed_entry.value_as_json

if managed_entry.created_at:
document["created_at"] = managed_entry.created_at.isoformat()
Expand All @@ -89,15 +99,31 @@ def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedE


def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
    """Convert an Elasticsearch document ``_source`` into a ManagedEntry.

    Supports both storage modes: a native dict under ``value.flattened`` or a
    JSON string under ``value.string``. ``created_at`` / ``expires_at`` are
    parsed best-effort from ISO strings and may be ``None``.

    Raises:
        DeserializationError: If the ``value`` field is missing, not a dict,
            contains neither a ``flattened`` nor a ``string`` entry, or the
            ``string`` entry is not a string.
    """
    raw_value = source.get("value")

    if not isinstance(raw_value, dict) or not raw_value:
        msg = "Value field not found or invalid type"
        raise DeserializationError(msg)

    value: dict[str, Any]

    # Prefer the native `flattened` field, fall back to the `string` field.
    # Explicit `is not None` checks (rather than truthiness) so that a
    # legitimately-empty stored dict or empty JSON string is not
    # misclassified as "value missing".
    if (value_flattened := raw_value.get("flattened")) is not None:  # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
        value = verify_dict(obj=value_flattened)
    elif (value_str := raw_value.get("string")) is not None:  # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
        if not isinstance(value_str, str):
            msg = "Value in `value` field is not a string"
            raise DeserializationError(msg)
        value = load_from_json(value_str)
    else:
        msg = "Value field not found or invalid type"
        raise DeserializationError(msg)

    created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
    expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))

    return ManagedEntry(
        value=value,
        created_at=created_at,
        expires_at=expires_at,
    )
Expand All @@ -114,11 +140,28 @@ class ElasticsearchStore(

_index_prefix: str

_native_storage: bool

@overload
def __init__(self, *, elasticsearch_client: AsyncElasticsearch, index_prefix: str, default_collection: str | None = None) -> None: ...
def __init__(
self,
*,
elasticsearch_client: AsyncElasticsearch,
index_prefix: str,
native_storage: bool = True,
default_collection: str | None = None,
) -> None: ...

@overload
def __init__(self, *, url: str, api_key: str | None = None, index_prefix: str, default_collection: str | None = None) -> None: ...
def __init__(
self,
*,
url: str,
api_key: str | None = None,
index_prefix: str,
native_storage: bool = True,
default_collection: str | None = None,
) -> None: ...

def __init__(
self,
Expand All @@ -127,6 +170,7 @@ def __init__(
url: str | None = None,
api_key: str | None = None,
index_prefix: str,
native_storage: bool = True,
default_collection: str | None = None,
) -> None:
"""Initialize the elasticsearch store.
Expand All @@ -136,6 +180,8 @@ def __init__(
url: The url of the elasticsearch cluster.
api_key: The api key to use.
index_prefix: The index prefix to use. Collections will be prefixed with this prefix.
native_storage: Whether to use native storage mode (flattened field type) or serialize
all values to JSON strings. Defaults to True.
default_collection: The default collection to use if no collection is provided.
"""
if elasticsearch_client is None and url is None:
Expand All @@ -152,7 +198,12 @@ def __init__(
msg = "Either elasticsearch_client or url must be provided"
raise ValueError(msg)

LessCapableJsonSerializer.install_serializer(client=self._client)
LessCapableJsonSerializer.install_default_serializer(client=self._client)
LessCapableNdjsonSerializer.install_serializer(client=self._client)

self._index_prefix = index_prefix
self._native_storage = native_storage
self._is_serverless = False

super().__init__(default_collection=default_collection)
Expand Down Expand Up @@ -205,18 +256,11 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
if not (source := get_source_from_body(body=body)):
return None

if not (value_str := source.get("value")) or not isinstance(value_str, str):
try:
return source_to_managed_entry(source=source)
except DeserializationError:
return None

created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))

return ManagedEntry(
value=load_from_json(value_str),
created_at=created_at,
expires_at=expires_at,
)

@override
async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]:
if not keys:
Expand Down Expand Up @@ -265,15 +309,23 @@ async def _put_managed_entry(
index_name: str = self._sanitize_index_name(collection=collection)
document_id: str = self._sanitize_document_id(key=key)

document: dict[str, Any] = managed_entry_to_document(collection=collection, key=key, managed_entry=managed_entry)

_ = await self._client.index(
index=index_name,
id=document_id,
body=document,
refresh=self._should_refresh_on_put,
document: dict[str, Any] = managed_entry_to_document(
collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage
)

try:
_ = await self._client.index(
index=index_name,
id=document_id,
body=document,
refresh=self._should_refresh_on_put,
)
except ElasticsearchSerializationError as e:
msg = f"Failed to serialize document: {e}"
raise SerializationError(message=msg) from e
except Exception:
raise

@override
async def _put_managed_entries(
self,
Expand All @@ -297,11 +349,18 @@ async def _put_managed_entries(

index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)

document: dict[str, Any] = managed_entry_to_document(collection=collection, key=key, managed_entry=managed_entry)
document: dict[str, Any] = managed_entry_to_document(
collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage
)

operations.extend([index_action, document])

_ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType]
try:
_ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType]
except ElasticsearchSerializationError as e:
msg = f"Failed to serialize bulk operations: {e}"
raise SerializationError(message=msg) from e
except Exception:
raise

@override
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from typing import Any, TypeVar, cast
from typing import Any, ClassVar, TypeVar, cast

from elastic_transport import ObjectApiResponse
from key_value.shared.utils.managed_entry import ManagedEntry
from elastic_transport import (
JsonSerializer,
NdjsonSerializer,
ObjectApiResponse,
SerializationError,
)

from elasticsearch import AsyncElasticsearch


def get_body_from_response(response: ObjectApiResponse[Any]) -> dict[str, Any]:
Expand All @@ -28,7 +34,10 @@ def get_aggregations_from_body(body: dict[str, Any]) -> dict[str, Any]:
if not (aggregations := body.get("aggregations")):
return {}

if not isinstance(aggregations, dict) or not all(isinstance(key, str) for key in aggregations): # pyright: ignore[reportUnknownVariableType]
if not isinstance(aggregations, dict) or not all(
isinstance(key, str)
for key in aggregations # pyright: ignore[reportUnknownVariableType]
):
return {}

return cast("dict[str, Any]", aggregations)
Expand Down Expand Up @@ -108,20 +117,50 @@ def get_first_value_from_field_in_hit(hit: dict[str, Any], field: str, value_typ
return values[0]


def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:
document: dict[str, Any] = {
"collection": collection,
"key": key,
"value": managed_entry.to_json(include_metadata=False),
}
def new_bulk_action(action: str, index: str, document_id: str) -> dict[str, Any]:
    """Build a single bulk-API action header (e.g. ``{"index": {...}}``)."""
    metadata: dict[str, Any] = {"_index": index, "_id": document_id}
    return {action: metadata}


if managed_entry.created_at:
document["created_at"] = managed_entry.created_at.isoformat()
if managed_entry.expires_at:
document["expires_at"] = managed_entry.expires_at.isoformat()
class LessCapableJsonSerializer(JsonSerializer):
    """A JSON serializer that doesn't try to be smart with datetime, floats, etc.

    The stock serializer's ``default`` hook coerces unsupported types on the
    way to Elasticsearch; this subclass refuses instead, so values round-trip
    losslessly or fail loudly at write time.
    """

    mimetype: ClassVar[str] = "application/json"
    compatibility_mimetype: ClassVar[str] = "application/vnd.elasticsearch+json"

    def default(self, data: Any) -> Any:
        """Reject any object the base JSON encoder cannot handle natively."""
        raise SerializationError(
            message=f"Unable to serialize to JSON: {data!r} (type: {type(data).__name__})",
        )

    @classmethod
    def install_default_serializer(cls, client: AsyncElasticsearch) -> None:
        """Register this serializer and also make it the transport's default."""
        cls.install_serializer(client=client)
        client.transport.serializers.default_serializer = cls()

    @classmethod
    def install_serializer(cls, client: AsyncElasticsearch) -> None:
        """Register this serializer for its standard and compatibility mimetypes."""
        client.transport.serializers.serializers.update(
            {
                cls.mimetype: cls(),
                cls.compatibility_mimetype: cls(),
            }
        )


class LessCapableNdjsonSerializer(NdjsonSerializer):
    """An NDJSON serializer that doesn't try to be smart with datetime, floats, etc."""

    mimetype: ClassVar[str] = "application/x-ndjson"
    compatibility_mimetype: ClassVar[str] = "application/vnd.elasticsearch+x-ndjson"

    def default(self, data: Any) -> Any:
        """Reject any object the base JSON encoder cannot handle natively."""
        # Raise directly rather than calling LessCapableJsonSerializer.default
        # with a mismatched `self` (which required several pyright suppressions).
        raise SerializationError(
            message=f"Unable to serialize to JSON: {data!r} (type: {type(data).__name__})",
        )

    @classmethod
    def install_serializer(cls, client: AsyncElasticsearch) -> None:
        """Register this serializer for its standard and compatibility mimetypes."""
        client.transport.serializers.serializers.update(
            {
                cls.mimetype: cls(),
                cls.compatibility_mimetype: cls(),
            }
        )
Loading