Skip to content

Clean up some key and UUID->str conversion handling #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions langgraph/checkpoint/redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,10 @@ def put(

thread_id = configurable.pop("thread_id")
checkpoint_ns = configurable.pop("checkpoint_ns")
checkpoint_id = checkpoint_id = configurable.pop(
"checkpoint_id", configurable.pop("thread_ts", "")
thread_ts = configurable.pop("thread_ts", "")
checkpoint_id = (
configurable.pop("checkpoint_id", configurable.pop("thread_ts", ""))
or thread_ts
)

# For values we store in Redis, we need to convert empty strings to the
Expand Down
31 changes: 21 additions & 10 deletions langgraph/checkpoint/redis/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import partial
from sys import thread_info
from types import TracebackType
from typing import Any, List, Optional, Sequence, Tuple, Type, cast

Expand Down Expand Up @@ -375,10 +374,20 @@ async def aput(
) -> RunnableConfig:
"""Store a checkpoint to Redis."""
configurable = config["configurable"].copy()

thread_id = configurable.pop("thread_id")
checkpoint_ns = configurable.pop("checkpoint_ns")
thread_ts = configurable.pop("thread_ts", "")
checkpoint_id = configurable.pop("checkpoint_id", thread_ts) or thread_ts
checkpoint_id = (
configurable.pop("checkpoint_id", configurable.pop("thread_ts", ""))
or thread_ts
)

# For values we store in Redis, we need to convert empty strings to the
# sentinel value.
storage_safe_thread_id = to_storage_safe_id(thread_id)
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)

copy = checkpoint.copy()
next_config = {
Expand All @@ -391,32 +400,34 @@ async def aput(

# Store checkpoint data
checkpoint_data = {
"thread_id": thread_id,
"checkpoint_ns": to_storage_safe_str(checkpoint_ns),
"checkpoint_id": to_storage_safe_id(checkpoint_id),
"parent_checkpoint_id": to_storage_safe_id(checkpoint_id),
"thread_id": storage_safe_thread_id,
"checkpoint_ns": storage_safe_checkpoint_ns,
"checkpoint_id": storage_safe_checkpoint_id,
"parent_checkpoint_id": storage_safe_checkpoint_id,
"checkpoint": self._dump_checkpoint(copy),
"metadata": self._dump_metadata(metadata),
}

# store at top-level for filters in list()
if all(key in metadata for key in ["source", "step"]):
checkpoint_data["source"] = metadata["source"]
checkpoint_data["step"] = metadata["step"]
checkpoint_data["step"] = metadata["step"] # type: ignore

await self.checkpoints_index.load(
[checkpoint_data],
keys=[
BaseRedisSaver._make_redis_checkpoint_key(
thread_id, checkpoint_ns, checkpoint_id
storage_safe_thread_id,
storage_safe_checkpoint_ns,
storage_safe_checkpoint_id,
)
],
)

# Store blob values
blobs = self._dump_blobs(
thread_id,
checkpoint_ns,
storage_safe_thread_id,
storage_safe_checkpoint_ns,
copy.get("channel_values", {}),
new_versions,
)
Expand Down
10 changes: 5 additions & 5 deletions langgraph/checkpoint/redis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,9 @@ def _make_redis_checkpoint_key(
return REDIS_KEY_SEPARATOR.join(
[
CHECKPOINT_PREFIX,
to_storage_safe_id(thread_id),
str(to_storage_safe_id(thread_id)),
to_storage_safe_str(checkpoint_ns),
to_storage_safe_id(checkpoint_id),
str(to_storage_safe_id(checkpoint_id)),
]
)

Expand All @@ -470,7 +470,7 @@ def _make_redis_checkpoint_blob_key(
return REDIS_KEY_SEPARATOR.join(
[
CHECKPOINT_BLOB_PREFIX,
to_storage_safe_str(thread_id),
str(to_storage_safe_id(thread_id)),
to_storage_safe_str(checkpoint_ns),
channel,
version,
Expand All @@ -485,9 +485,9 @@ def _make_redis_checkpoint_writes_key(
task_id: str,
idx: Optional[int],
) -> str:
storage_safe_thread_id = to_storage_safe_str(thread_id)
storage_safe_thread_id = str(to_storage_safe_id(thread_id))
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
storage_safe_checkpoint_id = to_storage_safe_str(checkpoint_id)
storage_safe_checkpoint_id = str(to_storage_safe_id(checkpoint_id))

if idx is None:
return REDIS_KEY_SEPARATOR.join(
Expand Down
60 changes: 15 additions & 45 deletions langgraph/checkpoint/redis/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Any, Callable, Optional, TypeVar, Union

"""
RediSearch versions below 2.10 don't support indexing and querying
empty strings, so we use a sentinel value to represent empty strings.
Expand All @@ -8,14 +6,17 @@
sentinel values are from the first run of the graph, so this should
generally be correct.
"""

EMPTY_STRING_SENTINEL = "__empty__"
EMPTY_ID_SENTINEL = "00000000-0000-0000-0000-000000000000"


def to_storage_safe_str(value: str) -> str:
"""
Convert any empty string to an empty string sentinel if found,
otherwise return the value unchanged.
Prepare a value for storage in Redis as a string.

Convert an empty string to a sentinel value, otherwise return the
value as a string.

Args:
value (str): The value to convert.
Expand All @@ -26,13 +27,13 @@ def to_storage_safe_str(value: str) -> str:
if value == "":
return EMPTY_STRING_SENTINEL
else:
return value
return str(value)


def from_storage_safe_str(value: str) -> str:
"""
Convert a value from an empty string sentinel to an empty string
if found, otherwise return the value unchanged.
Convert a value from a sentinel value to an empty string if present,
otherwise return the value unchanged.

Args:
value (str): The value to convert.
Expand All @@ -48,8 +49,10 @@ def from_storage_safe_str(value: str) -> str:

def to_storage_safe_id(value: str) -> str:
"""
Convert any empty ID string to an empty ID sentinel if found,
otherwise return the value unchanged.
Prepare a value for storage in Redis as an ID.

Convert an empty string to a sentinel value for empty ID strings, otherwise
return the value as a string.

Args:
value (str): The value to convert.
Expand All @@ -60,13 +63,13 @@ def to_storage_safe_id(value: str) -> str:
if value == "":
return EMPTY_ID_SENTINEL
else:
return value
return str(value)


def from_storage_safe_id(value: str) -> str:
"""
Convert a value from an empty ID sentinel to an empty ID
if found, otherwise return the value unchanged.
Convert a value from a sentinel value for empty ID strings to an empty
ID string if present, otherwise return the value unchanged.

Args:
value (str): The value to convert.
Expand All @@ -78,36 +81,3 @@ def from_storage_safe_id(value: str) -> str:
return ""
else:
return value


def storage_safe_get(
doc: dict[str, Any], key: str, default: Any = None
) -> Optional[Any]:
"""
Get a value from a Redis document or dictionary, using a sentinel
value to represent empty strings.

If the sentinel value is found, it is converted back to an empty string.

Args:
doc (dict[str, Any]): The document to get the value from.
key (str): The key to get the value from.
default (Any): The default value to return if the key is not found.
Returns:
Optional[Any]: None if the key is not found, or else the value from
the document or dictionary, with empty strings converted
to the empty string sentinel and the sentinel converted
back to an empty string.
"""
try:
# NOTE: The Document class that comes back from `search()` support
# [key] access but not `get()` for some reason, so we use direct
# key access with an exception guard.
value = doc[key]
except KeyError:
value = None

if value is None:
return default

return to_storage_safe_str(value)