Skip to content

Handle missing config keys during checkpoint search #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ on:
branches:
- main

schedule:
- cron: "0 2 * * *" # 2 AM UTC nightly

workflow_dispatch:


env:
POETRY_VERSION: "1.8.3"

Expand Down
139 changes: 91 additions & 48 deletions langgraph/checkpoint/redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
from langgraph.checkpoint.redis.ashallow import AsyncShallowRedisSaver
from langgraph.checkpoint.redis.base import BaseRedisSaver
from langgraph.checkpoint.redis.shallow import ShallowRedisSaver
from langgraph.checkpoint.redis.util import (
EMPTY_ID_SENTINEL,
from_storage_safe_id,
from_storage_safe_str,
to_storage_safe_id,
to_storage_safe_str,
)
from langgraph.checkpoint.redis.version import __lib_name__, __version__


Expand Down Expand Up @@ -79,12 +86,21 @@ def list(
filter_expression = []
if config:
filter_expression.append(
Tag("thread_id") == config["configurable"]["thread_id"]
Tag("thread_id")
== to_storage_safe_id(config["configurable"]["thread_id"])
)

# Reproducing the logic from the Postgres implementation, we'll
# search for checkpoints with any namespace, including an empty
# string, while `checkpoint_id` has to have a value.
if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
filter_expression.append(Tag("checkpoint_ns") == checkpoint_ns)
filter_expression.append(
Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns)
)
if checkpoint_id := get_checkpoint_id(config):
filter_expression.append(Tag("checkpoint_id") == checkpoint_id)
filter_expression.append(
Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)
)

if filter:
for k, v in filter.items():
Expand Down Expand Up @@ -122,9 +138,10 @@ def list(

# Process the results
for doc in results.docs:
thread_id = str(getattr(doc, "thread_id", ""))
checkpoint_ns = str(getattr(doc, "checkpoint_ns", ""))
checkpoint_id = str(getattr(doc, "checkpoint_id", ""))
thread_id = from_storage_safe_id(doc["thread_id"])
checkpoint_ns = from_storage_safe_str(doc["checkpoint_ns"])
checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])

# Fetch channel_values
channel_values = self.get_channel_values(
Expand All @@ -135,11 +152,11 @@ def list(

# Fetch pending_sends from parent checkpoint
pending_sends = []
if doc["parent_checkpoint_id"]:
if parent_checkpoint_id:
pending_sends = self._load_pending_sends(
thread_id=thread_id,
checkpoint_ns=checkpoint_ns,
parent_checkpoint_id=doc["parent_checkpoint_id"],
parent_checkpoint_id=parent_checkpoint_id,
)

# Fetch and parse metadata
Expand All @@ -163,7 +180,7 @@ def list(
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": doc["checkpoint_id"],
"checkpoint_id": checkpoint_id,
}
}

Expand Down Expand Up @@ -194,49 +211,60 @@ def put(
) -> RunnableConfig:
"""Store a checkpoint to Redis."""
configurable = config["configurable"].copy()

thread_id = configurable.pop("thread_id")
checkpoint_ns = configurable.pop("checkpoint_ns")
checkpoint_id = configurable.pop(
"checkpoint_id", configurable.pop("thread_ts", None)
checkpoint_id = checkpoint_id = configurable.pop(
"checkpoint_id", configurable.pop("thread_ts", "")
)

# For values we store in Redis, we need to convert empty strings to the
# sentinel value.
storage_safe_thread_id = to_storage_safe_id(thread_id)
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)

copy = checkpoint.copy()
# When we return the config, we need to preserve empty strings that
# were passed in, instead of the sentinel value.
next_config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
"checkpoint_id": checkpoint_id,
}
}

# Store checkpoint data
# Store checkpoint data.
checkpoint_data = {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
"parent_checkpoint_id": checkpoint_id,
"thread_id": storage_safe_thread_id,
"checkpoint_ns": storage_safe_checkpoint_ns,
"checkpoint_id": storage_safe_checkpoint_id,
"parent_checkpoint_id": storage_safe_checkpoint_id,
"checkpoint": self._dump_checkpoint(copy),
"metadata": self._dump_metadata(metadata),
}

# store at top-level for filters in list()
if all(key in metadata for key in ["source", "step"]):
checkpoint_data["source"] = metadata["source"]
checkpoint_data["step"] = metadata["step"]
checkpoint_data["step"] = metadata["step"] # type: ignore

self.checkpoints_index.load(
[checkpoint_data],
keys=[
BaseRedisSaver._make_redis_checkpoint_key(
thread_id, checkpoint_ns, checkpoint["id"]
storage_safe_thread_id,
storage_safe_checkpoint_ns,
storage_safe_checkpoint_id,
)
],
)

# Store blob values
# Store blob values.
blobs = self._dump_blobs(
thread_id,
checkpoint_ns,
storage_safe_thread_id,
storage_safe_checkpoint_ns,
copy.get("channel_values", {}),
new_versions,
)
Expand All @@ -258,19 +286,22 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
thread_id = config["configurable"]["thread_id"]
checkpoint_id = str(get_checkpoint_id(config))
checkpoint_id = get_checkpoint_id(config)
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")

if checkpoint_id:
ascending = True

if checkpoint_id and checkpoint_id != EMPTY_ID_SENTINEL:
checkpoint_filter_expression = (
(Tag("thread_id") == thread_id)
& (Tag("checkpoint_ns") == checkpoint_ns)
& (Tag("checkpoint_id") == checkpoint_id)
(Tag("thread_id") == to_storage_safe_id(thread_id))
& (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns))
& (Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id))
)
else:
checkpoint_filter_expression = (Tag("thread_id") == thread_id) & (
Tag("checkpoint_ns") == checkpoint_ns
)
checkpoint_filter_expression = (
Tag("thread_id") == to_storage_safe_id(thread_id)
) & (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns))
ascending = False

# Construct the query
checkpoints_query = FilterQuery(
Expand All @@ -285,29 +316,33 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
],
num_results=1,
)
checkpoints_query.sort_by("checkpoint_id", asc=False)
checkpoints_query.sort_by("checkpoint_id", asc=ascending)

# Execute the query
results = self.checkpoints_index.search(checkpoints_query)
if not results.docs:
return None

doc = results.docs[0]
doc_thread_id = from_storage_safe_id(doc["thread_id"])
doc_checkpoint_ns = from_storage_safe_str(doc["checkpoint_ns"])
doc_checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
doc_parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])

# Fetch channel_values
channel_values = self.get_channel_values(
thread_id=doc["thread_id"],
checkpoint_ns=doc["checkpoint_ns"],
checkpoint_id=doc["checkpoint_id"],
thread_id=doc_thread_id,
checkpoint_ns=doc_checkpoint_ns,
checkpoint_id=doc_checkpoint_id,
)

# Fetch pending_sends from parent checkpoint
pending_sends = []
if doc["parent_checkpoint_id"]:
if doc_parent_checkpoint_id:
pending_sends = self._load_pending_sends(
thread_id=thread_id,
checkpoint_ns=checkpoint_ns,
parent_checkpoint_id=doc["parent_checkpoint_id"],
thread_id=doc_thread_id,
checkpoint_ns=doc_checkpoint_ns,
parent_checkpoint_id=doc_parent_checkpoint_id,
)

# Fetch and parse metadata
Expand All @@ -329,7 +364,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": doc["checkpoint_id"],
"checkpoint_id": doc_checkpoint_id,
}
}

Expand All @@ -340,7 +375,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
)

pending_writes = self._load_pending_writes(
thread_id, checkpoint_ns, checkpoint_id
thread_id, checkpoint_ns, doc_checkpoint_id
)

return CheckpointTuple(
Expand Down Expand Up @@ -379,10 +414,14 @@ def get_channel_values(
self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = ""
) -> dict[str, Any]:
"""Retrieve channel_values dictionary with properly constructed message objects."""
storage_safe_thread_id = to_storage_safe_id(thread_id)
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)

checkpoint_query = FilterQuery(
filter_expression=(Tag("thread_id") == thread_id)
& (Tag("checkpoint_ns") == checkpoint_ns)
& (Tag("checkpoint_id") == checkpoint_id),
filter_expression=(Tag("thread_id") == storage_safe_thread_id)
& (Tag("checkpoint_ns") == storage_safe_checkpoint_ns)
& (Tag("checkpoint_id") == storage_safe_checkpoint_id),
return_fields=["$.checkpoint.channel_versions"],
num_results=1,
)
Expand All @@ -400,8 +439,8 @@ def get_channel_values(
channel_values = {}
for channel, version in channel_versions.items():
blob_query = FilterQuery(
filter_expression=(Tag("thread_id") == thread_id)
& (Tag("checkpoint_ns") == checkpoint_ns)
filter_expression=(Tag("thread_id") == storage_safe_thread_id)
& (Tag("checkpoint_ns") == storage_safe_checkpoint_ns)
& (Tag("channel") == channel)
& (Tag("version") == version),
return_fields=["type", "$.blob"],
Expand Down Expand Up @@ -437,11 +476,15 @@ def _load_pending_sends(
Returns:
List of (type, blob) tuples representing pending sends
"""
storage_safe_thread_id = to_storage_safe_str(thread_id)
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
storage_safe_parent_checkpoint_id = to_storage_safe_str(parent_checkpoint_id)

# Query checkpoint_writes for parent checkpoint's TASKS channel
parent_writes_query = FilterQuery(
filter_expression=(Tag("thread_id") == thread_id)
& (Tag("checkpoint_ns") == checkpoint_ns)
& (Tag("checkpoint_id") == parent_checkpoint_id)
filter_expression=(Tag("thread_id") == storage_safe_thread_id)
& (Tag("checkpoint_ns") == storage_safe_checkpoint_ns)
& (Tag("checkpoint_id") == storage_safe_parent_checkpoint_id)
& (Tag("channel") == TASKS),
return_fields=["type", "blob", "task_path", "task_id", "idx"],
num_results=100, # Adjust as needed
Expand Down
Loading