Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions key-value/key-value-aio/src/key_value/aio/stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from key_value.shared.errors import StoreSetupError
from key_value.shared.type_checking.bear_spray import bear_enforce
from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.utils.sanitization import PassthroughStrategy, SanitizationStrategy
from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter
from key_value.shared.utils.time_to_live import prepare_entry_timestamps
from typing_extensions import Self, override
Expand Down Expand Up @@ -69,6 +70,8 @@ class BaseStore(AsyncKeyValueProtocol, ABC):
_setup_collection_complete: defaultdict[str, bool]

_serialization_adapter: SerializationAdapter
_key_sanitization_strategy: SanitizationStrategy
_collection_sanitization_strategy: SanitizationStrategy

_seed: FROZEN_SEED_DATA_TYPE

Expand All @@ -78,13 +81,17 @@ def __init__(
self,
*,
serialization_adapter: SerializationAdapter | None = None,
key_sanitization_strategy: SanitizationStrategy | None = None,
collection_sanitization_strategy: SanitizationStrategy | None = None,
default_collection: str | None = None,
seed: SEED_DATA_TYPE | None = None,
) -> None:
"""Initialize the managed key-value store.

Args:
serialization_adapter: The serialization adapter to use for the store.
key_sanitization_strategy: The sanitization strategy to use for keys.
collection_sanitization_strategy: The sanitization strategy to use for collections.
default_collection: The default collection to use if no collection is provided.
Defaults to "default_collection".
seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}.
Expand All @@ -103,6 +110,9 @@ def __init__(

self._serialization_adapter = serialization_adapter or BasicSerializationAdapter()

self._key_sanitization_strategy = key_sanitization_strategy or PassthroughStrategy()
self._collection_sanitization_strategy = collection_sanitization_strategy or PassthroughStrategy()

if not hasattr(self, "_stable_api"):
self._stable_api = False

Expand All @@ -117,6 +127,17 @@ async def _setup(self) -> None:
async def _setup_collection(self, *, collection: str) -> None:
"""Initialize the collection (called once before first use of the collection)."""

def _sanitize_collection_and_key(self, collection: str, key: str) -> tuple[str, str]:
return self._sanitize_collection(collection=collection), self._sanitize_key(key=key)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

Prefer explicit tuple parentheses for consistency.

The tuple return on line 130 omits explicit parentheses, while the sync version (from key-value-sync/src/key_value/sync/code_gen/stores/base.py) uses explicit parentheses. For consistency across async and sync implementations, consider adding them.

Apply this diff:

-        return self._sanitize_collection(collection=collection), self._sanitize_key(key=key)
+        return (self._sanitize_collection(collection=collection), self._sanitize_key(key=key))
🤖 Prompt for AI Agents
In key-value/key-value-aio/src/key_value/aio/stores/base.py around line 130, the
return statement returns a tuple without explicit parentheses; change it to
return the sanitized collection and key wrapped in explicit parentheses (e.g.,
return (self._sanitize_collection(collection=collection),
self._sanitize_key(key=key))) to match the sync implementation and maintain
consistency across modules.


def _sanitize_collection(self, collection: str) -> str:
self._collection_sanitization_strategy.validate(value=collection)
return self._collection_sanitization_strategy.sanitize(value=collection)

def _sanitize_key(self, key: str) -> str:
self._key_sanitization_strategy.validate(value=key)
return self._key_sanitization_strategy.sanitize(value=key)

async def _seed_store(self) -> None:
"""Seed the store with the data from the seed."""
for collection, items in self._seed.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from elastic_transport import SerializationError as ElasticsearchSerializationError
from key_value.shared.errors import DeserializationError, SerializationError
from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.utils.sanitization import AlwaysHashStrategy, HashFragmentMode, HybridSanitizationStrategy
from key_value.shared.utils.sanitize import (
ALPHANUMERIC_CHARACTERS,
LOWERCASE_ALPHABET,
NUMBERS,
sanitize_string,
UPPERCASE_ALPHABET,
)
from key_value.shared.utils.serialization import SerializationAdapter
from key_value.shared.utils.time_to_live import now_as_epoch
Expand Down Expand Up @@ -145,7 +146,7 @@ class ElasticsearchStore(

_native_storage: bool

_adapter: SerializationAdapter
_serializer: SerializationAdapter

@overload
def __init__(
Expand Down Expand Up @@ -207,12 +208,31 @@ def __init__(
LessCapableJsonSerializer.install_default_serializer(client=self._client)
LessCapableNdjsonSerializer.install_serializer(client=self._client)

self._index_prefix = index_prefix
self._index_prefix = index_prefix.lower()
self._native_storage = native_storage
self._is_serverless = False
self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage)

super().__init__(default_collection=default_collection)
# We have 240 characters to work with
# We need to account for the index prefix and the hyphen.
max_index_length = MAX_INDEX_LENGTH - (len(self._index_prefix) + 1)

self._serializer = ElasticsearchSerializationAdapter(native_storage=native_storage)

# We allow uppercase through the sanitizer so we can lowercase them instead of them
# all turning into underscores.
collection_sanitization = HybridSanitizationStrategy(
replacement_character="_",
max_length=max_index_length,
allowed_characters=UPPERCASE_ALPHABET + ALLOWED_INDEX_CHARACTERS,
hash_fragment_mode=HashFragmentMode.ALWAYS,
)
key_sanitization = AlwaysHashStrategy()

super().__init__(
default_collection=default_collection,
collection_sanitization_strategy=collection_sanitization,
key_sanitization_strategy=key_sanitization,
)

@override
async def _setup(self) -> None:
Expand All @@ -222,32 +242,22 @@ async def _setup(self) -> None:

@override
async def _setup_collection(self, *, collection: str) -> None:
index_name = self._sanitize_index_name(collection=collection)
index_name = self._get_index_name(collection=collection)

if await self._client.options(ignore_status=404).indices.exists(index=index_name):
return

_ = await self._client.options(ignore_status=404).indices.create(index=index_name, mappings=DEFAULT_MAPPING, settings={})

def _sanitize_index_name(self, collection: str) -> str:
return sanitize_string(
value=self._index_prefix + "-" + collection,
replacement_character="_",
max_length=MAX_INDEX_LENGTH,
allowed_characters=ALLOWED_INDEX_CHARACTERS,
)
def _get_index_name(self, collection: str) -> str:
return self._index_prefix + "-" + self._sanitize_collection(collection=collection).lower()

def _sanitize_document_id(self, key: str) -> str:
return sanitize_string(
value=key,
replacement_character="_",
max_length=MAX_KEY_LENGTH,
allowed_characters=ALLOWED_KEY_CHARACTERS,
)
def _get_document_id(self, key: str) -> str:
return self._sanitize_key(key=key)

def _get_destination(self, *, collection: str, key: str) -> tuple[str, str]:
index_name: str = self._sanitize_index_name(collection=collection)
document_id: str = self._sanitize_document_id(key=key)
index_name: str = self._get_index_name(collection=collection)
document_id: str = self._get_document_id(key=key)

return index_name, document_id

Expand All @@ -263,7 +273,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
return None

try:
return self._adapter.load_dict(data=source)
return self._serializer.load_dict(data=source)
except DeserializationError:
return None

Expand All @@ -273,8 +283,8 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
return []

# Use mget for efficient batch retrieval
index_name = self._sanitize_index_name(collection=collection)
document_ids = [self._sanitize_document_id(key=key) for key in keys]
index_name = self._get_index_name(collection=collection)
document_ids = [self._get_document_id(key=key) for key in keys]
docs = [{"_id": document_id} for document_id in document_ids]

elasticsearch_response = await self._client.options(ignore_status=404).mget(index=index_name, docs=docs)
Expand All @@ -296,7 +306,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
continue

try:
entries_by_id[doc_id] = self._adapter.load_dict(data=source)
entries_by_id[doc_id] = self._serializer.load_dict(data=source)
except DeserializationError as e:
logger.error(
"Failed to deserialize Elasticsearch document in batch operation",
Expand Down Expand Up @@ -324,10 +334,10 @@ async def _put_managed_entry(
collection: str,
managed_entry: ManagedEntry,
) -> None:
index_name: str = self._sanitize_index_name(collection=collection)
document_id: str = self._sanitize_document_id(key=key)
index_name: str = self._get_index_name(collection=collection)
document_id: str = self._get_document_id(key=key)

document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry)
document: dict[str, Any] = self._serializer.dump_dict(entry=managed_entry)

try:
_ = await self._client.index(
Expand Down Expand Up @@ -358,14 +368,14 @@ async def _put_managed_entries(

operations: list[dict[str, Any]] = []

index_name: str = self._sanitize_index_name(collection=collection)
index_name: str = self._get_index_name(collection=collection)

for key, managed_entry in zip(keys, managed_entries, strict=True):
document_id: str = self._sanitize_document_id(key=key)
document_id: str = self._get_document_id(key=key)

index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)

document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry)
document: dict[str, Any] = self._serializer.dump_dict(entry=managed_entry)

operations.extend([index_action, document])

Expand All @@ -379,8 +389,8 @@ async def _put_managed_entries(

@override
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
index_name: str = self._sanitize_index_name(collection=collection)
document_id: str = self._sanitize_document_id(key=key)
index_name: str = self._get_index_name(collection=collection)
document_id: str = self._get_document_id(key=key)

elasticsearch_response: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).delete(
index=index_name, id=document_id
Expand Down Expand Up @@ -428,7 +438,7 @@ async def _get_collection_keys(self, *, collection: str, limit: int | None = Non
limit = min(limit or DEFAULT_PAGE_SIZE, PAGE_LIMIT)

result: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).search(
index=self._sanitize_index_name(collection=collection),
index=self._get_index_name(collection=collection),
fields=[{"key": None}],
body={
"query": {
Expand Down Expand Up @@ -483,7 +493,7 @@ async def _get_collection_names(self, *, limit: int | None = None) -> list[str]:
@override
async def _delete_collection(self, *, collection: str) -> bool:
result: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).delete_by_query(
index=self._sanitize_index_name(collection=collection),
index=self._get_index_name(collection=collection),
body={
"query": {
"term": {
Expand Down
33 changes: 13 additions & 20 deletions key-value/key-value-aio/src/key_value/aio/stores/keyring/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from key_value.shared.utils.compound import compound_key
from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string
from key_value.shared.utils.sanitization import HybridSanitizationStrategy
from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS
from typing_extensions import override

from key_value.aio.stores.base import BaseStore
Expand All @@ -15,11 +16,9 @@
raise ImportError(msg) from e

DEFAULT_KEYCHAIN_SERVICE = "py-key-value"
MAX_KEY_LENGTH = 256
ALLOWED_KEY_CHARACTERS: str = ALPHANUMERIC_CHARACTERS

MAX_COLLECTION_LENGTH = 256
ALLOWED_COLLECTION_CHARACTERS: str = ALPHANUMERIC_CHARACTERS
MAX_KEY_COLLECTION_LENGTH = 256
ALLOWED_KEY_COLLECTION_CHARACTERS: str = ALPHANUMERIC_CHARACTERS


class KeyringStore(BaseStore):
Expand Down Expand Up @@ -48,25 +47,19 @@ def __init__(
"""
self._service_name = service_name

super().__init__(default_collection=default_collection)

def _sanitize_collection_name(self, collection: str) -> str:
return sanitize_string(
value=collection,
max_length=MAX_COLLECTION_LENGTH,
allowed_characters=ALLOWED_COLLECTION_CHARACTERS,
sanitization_strategy = HybridSanitizationStrategy(
replacement_character="_", max_length=MAX_KEY_COLLECTION_LENGTH, allowed_characters=ALLOWED_KEY_COLLECTION_CHARACTERS
)

def _sanitize_key(self, key: str) -> str:
return sanitize_string(
value=key,
max_length=MAX_KEY_LENGTH,
allowed_characters=ALLOWED_KEY_CHARACTERS,
super().__init__(
default_collection=default_collection,
collection_sanitization_strategy=sanitization_strategy,
key_sanitization_strategy=sanitization_strategy,
)

@override
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
sanitized_collection = self._sanitize_collection_name(collection=collection)
sanitized_collection = self._sanitize_collection(collection=collection)
sanitized_key = self._sanitize_key(key=key)

combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key)
Expand All @@ -83,7 +76,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry

@override
async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None:
sanitized_collection = self._sanitize_collection_name(collection=collection)
sanitized_collection = self._sanitize_collection(collection=collection)
sanitized_key = self._sanitize_key(key=key)

combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key)
Expand All @@ -94,7 +87,7 @@ async def _put_managed_entry(self, *, key: str, collection: str, managed_entry:

@override
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
sanitized_collection = self._sanitize_collection_name(collection=collection)
sanitized_collection = self._sanitize_collection(collection=collection)
sanitized_key = self._sanitize_key(key=key)

combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from key_value.shared.utils.compound import compound_key
from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.utils.sanitization import HashExcessLengthStrategy
from typing_extensions import override

from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseStore
Expand Down Expand Up @@ -46,7 +47,12 @@ def __init__(
"""
self._client = client or Client(host=host, port=port)

super().__init__(default_collection=default_collection)
sanitization_strategy = HashExcessLengthStrategy(max_length=MAX_KEY_LENGTH)

super().__init__(
default_collection=default_collection,
key_sanitization_strategy=sanitization_strategy,
)

def sanitize_key(self, key: str) -> str:
if len(key) > MAX_KEY_LENGTH:
Expand Down
Loading