Skip to content

release: Bump to 0.0.8 and loosen redisvl dep #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
fail-fast: false
matrix:
python-version: [3.9, '3.10', 3.11, 3.12, 3.13]
redis-version: ['6.2.6-v9', 'latest', '8.0-M03']
redis-version: ['6.2.6-v9', 'latest', '8.0.2']

steps:
- name: Check out repository
Expand All @@ -49,7 +49,7 @@ jobs:

- name: Set Redis image name
run: |
if [[ "${{ matrix.redis-version }}" == "8.0-M03" ]]; then
if [[ "${{ matrix.redis-version }}" == "8.0.2" ]]; then
echo "REDIS_IMAGE=redis:${{ matrix.redis-version }}" >> $GITHUB_ENV
else
echo "REDIS_IMAGE=redis/redis-stack-server:${{ matrix.redis-version }}" >> $GITHUB_ENV
Expand Down
2 changes: 1 addition & 1 deletion langgraph/checkpoint/redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def put(
# store at top-level for filters in list()
if all(key in metadata for key in ["source", "step"]):
checkpoint_data["source"] = metadata["source"]
checkpoint_data["step"] = metadata["step"] # type: ignore
checkpoint_data["step"] = metadata["step"]

# Create the checkpoint key
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
Expand Down
74 changes: 28 additions & 46 deletions langgraph/checkpoint/redis/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import logging
import os
from contextlib import asynccontextmanager
from functools import partial
from types import TracebackType
from typing import (
Any,
Expand All @@ -34,12 +33,10 @@
)
from langgraph.constants import TASKS
from redis.asyncio import Redis as AsyncRedis
from redis.asyncio.client import Pipeline
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
from redisvl.index import AsyncSearchIndex
from redisvl.query import FilterQuery
from redisvl.query.filter import Num, Tag
from redisvl.redis.connection import RedisConnectionFactory

from langgraph.checkpoint.redis.base import BaseRedisSaver
from langgraph.checkpoint.redis.util import (
Expand All @@ -54,25 +51,6 @@
logger = logging.getLogger(__name__)


async def _write_obj_tx(
pipe: Pipeline,
key: str,
write_obj: Dict[str, Any],
upsert_case: bool,
) -> None:
exists: int = await pipe.exists(key)
if upsert_case:
if exists:
await pipe.json().set(key, "$.channel", write_obj["channel"])
await pipe.json().set(key, "$.type", write_obj["type"])
await pipe.json().set(key, "$.blob", write_obj["blob"])
else:
await pipe.json().set(key, "$", write_obj)
else:
if not exists:
await pipe.json().set(key, "$", write_obj)


class AsyncRedisSaver(
BaseRedisSaver[Union[AsyncRedis, AsyncRedisCluster], AsyncSearchIndex]
):
Expand Down Expand Up @@ -568,7 +546,7 @@ async def aput(
# store at top-level for filters in list()
if all(key in metadata for key in ["source", "step"]):
checkpoint_data["source"] = metadata["source"]
checkpoint_data["step"] = metadata["step"] # type: ignore
checkpoint_data["step"] = metadata["step"]

# Prepare checkpoint key
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
Expand All @@ -587,11 +565,11 @@ async def aput(

if self.cluster_mode:
# For cluster mode, execute operations individually
await self._redis.json().set(checkpoint_key, "$", checkpoint_data)
await self._redis.json().set(checkpoint_key, "$", checkpoint_data) # type: ignore[misc]

if blobs:
for key, data in blobs:
await self._redis.json().set(key, "$", data)
await self._redis.json().set(key, "$", data) # type: ignore[misc]

# Apply TTL if configured
if self.ttl_config and "default_ttl" in self.ttl_config:
Expand All @@ -604,12 +582,12 @@ async def aput(
pipeline = self._redis.pipeline(transaction=True)

# Add checkpoint data to pipeline
await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
pipeline.json().set(checkpoint_key, "$", checkpoint_data)

if blobs:
# Add all blob operations to the pipeline
for key, data in blobs:
await pipeline.json().set(key, "$", data)
pipeline.json().set(key, "$", data)

# Execute all operations atomically
await pipeline.execute()
Expand Down Expand Up @@ -654,13 +632,13 @@ async def aput(

if self.cluster_mode:
# For cluster mode, execute operation directly
await self._redis.json().set(
await self._redis.json().set( # type: ignore[misc]
checkpoint_key, "$", checkpoint_data
)
else:
# For non-cluster mode, use pipeline
pipeline = self._redis.pipeline(transaction=True)
await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
pipeline.json().set(checkpoint_key, "$", checkpoint_data)
await pipeline.execute()
except Exception:
# If this also fails, we just propagate the original cancellation
Expand Down Expand Up @@ -739,24 +717,18 @@ async def aput_writes(
exists = await self._redis.exists(key)
if exists:
# Update existing key
await self._redis.json().set(
key, "$.channel", write_obj["channel"]
)
await self._redis.json().set(
key, "$.type", write_obj["type"]
)
await self._redis.json().set(
key, "$.blob", write_obj["blob"]
)
await self._redis.json().set(key, "$.channel", write_obj["channel"]) # type: ignore[misc, arg-type]
await self._redis.json().set(key, "$.type", write_obj["type"]) # type: ignore[misc, arg-type]
await self._redis.json().set(key, "$.blob", write_obj["blob"]) # type: ignore[misc, arg-type]
else:
# Create new key
await self._redis.json().set(key, "$", write_obj)
await self._redis.json().set(key, "$", write_obj) # type: ignore[misc]
created_keys.append(key)
else:
# For non-upsert case, only set if key doesn't exist
exists = await self._redis.exists(key)
if not exists:
await self._redis.json().set(key, "$", write_obj)
await self._redis.json().set(key, "$", write_obj) # type: ignore[misc]
created_keys.append(key)

# Apply TTL to newly created keys
Expand Down Expand Up @@ -788,20 +760,30 @@ async def aput_writes(
exists = await self._redis.exists(key)
if exists:
# Update existing key
await pipeline.json().set(
key, "$.channel", write_obj["channel"]
pipeline.json().set(
key,
"$.channel",
write_obj["channel"], # type: ignore[arg-type]
)
pipeline.json().set(
key,
"$.type",
write_obj["type"], # type: ignore[arg-type]
)
pipeline.json().set(
key,
"$.blob",
write_obj["blob"], # type: ignore[arg-type]
)
await pipeline.json().set(key, "$.type", write_obj["type"])
await pipeline.json().set(key, "$.blob", write_obj["blob"])
else:
# Create new key
await pipeline.json().set(key, "$", write_obj)
pipeline.json().set(key, "$", write_obj)
created_keys.append(key)
else:
# For non-upsert case, only set if key doesn't exist
exists = await self._redis.exists(key)
if not exists:
await pipeline.json().set(key, "$", write_obj)
pipeline.json().set(key, "$", write_obj)
created_keys.append(key)

# Execute all operations atomically
Expand Down
35 changes: 12 additions & 23 deletions langgraph/checkpoint/redis/ashallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,6 @@
]


# func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
async def _write_obj_tx(pipe: Pipeline, key: str, write_obj: dict[str, Any]) -> None:
exists: int = await pipe.exists(key)
if exists:
await pipe.json().set(key, "$.channel", write_obj["channel"])
await pipe.json().set(key, "$.type", write_obj["type"])
await pipe.json().set(key, "$.blob", write_obj["blob"])
else:
await pipe.json().set(key, "$", write_obj)


class AsyncShallowRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]):
"""Async Redis implementation that only stores the most recent checkpoint."""

Expand Down Expand Up @@ -240,7 +229,7 @@ async def aput(
)

# Add checkpoint data to pipeline
await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
pipeline.json().set(checkpoint_key, "$", checkpoint_data)

# Before storing the new blobs, clean up old ones that won't be needed
# - Get a list of all blob keys for this thread_id and checkpoint_ns
Expand Down Expand Up @@ -274,7 +263,7 @@ async def aput(
continue
else:
# This is an old version, delete it
await pipeline.delete(blob_key)
pipeline.delete(blob_key)

# Store the new blob values
blobs = self._dump_blobs(
Expand All @@ -287,7 +276,7 @@ async def aput(
if blobs:
# Add all blob data to pipeline
for key, data in blobs:
await pipeline.json().set(key, "$", data)
pipeline.json().set(key, "$", data)

# Execute all operations atomically
await pipeline.execute()
Expand Down Expand Up @@ -571,7 +560,7 @@ async def aput_writes(

# If the write is for a different checkpoint_id, delete it
if key_checkpoint_id != checkpoint_id:
await pipeline.delete(write_key)
pipeline.delete(write_key)

# Add new writes to the pipeline
upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
Expand All @@ -589,17 +578,15 @@ async def aput_writes(
exists = await self._redis.exists(key)
if exists:
# Update existing key
await pipeline.json().set(
key, "$.channel", write_obj["channel"]
)
await pipeline.json().set(key, "$.type", write_obj["type"])
await pipeline.json().set(key, "$.blob", write_obj["blob"])
pipeline.json().set(key, "$.channel", write_obj["channel"])
pipeline.json().set(key, "$.type", write_obj["type"])
pipeline.json().set(key, "$.blob", write_obj["blob"])
else:
# Create new key
await pipeline.json().set(key, "$", write_obj)
pipeline.json().set(key, "$", write_obj)
else:
# For shallow implementation, always set the full object
await pipeline.json().set(key, "$", write_obj)
pipeline.json().set(key, "$", write_obj)

# Execute all operations atomically
await pipeline.execute()
Expand Down Expand Up @@ -722,7 +709,9 @@ async def _aload_pending_writes(
(
parsed_key["task_id"],
parsed_key["idx"],
): await self._redis.json().get(key)
): await self._redis.json().get(
key
) # type: ignore[misc]
for key, parsed_key in sorted(
zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"]
)
Expand Down
6 changes: 4 additions & 2 deletions langgraph/checkpoint/redis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,9 @@ def _dump_metadata(self, metadata: CheckpointMetadata) -> str:
# NOTE: we're using JSON serializer (not msgpack), so we need to remove null characters before writing
return serialized_metadata.decode().replace("\\u0000", "")

def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
def get_next_version( # type: ignore[override]
self, current: Optional[str], channel: ChannelProtocol[Any, Any, Any]
) -> str:
"""Generate next version number."""
if current is None:
current_v = 0
Expand Down Expand Up @@ -420,7 +422,7 @@ def _load_writes_from_redis(self, write_key: str) -> List[Tuple[str, str, Any]]:
return []

writes = []
for write in result["writes"]:
for write in result["writes"]: # type: ignore[call-overload]
writes.append(
(
write["task_id"],
Expand Down
14 changes: 5 additions & 9 deletions langgraph/store/redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def _batch_search_ops(
if not isinstance(store_doc, dict):
try:
store_doc = json.loads(
store_doc
store_doc # type: ignore[arg-type]
) # Attempt to parse if it's a JSON string
except (json.JSONDecodeError, TypeError):
logger.error(f"Failed to parse store_doc: {store_doc}")
Expand Down Expand Up @@ -578,16 +578,14 @@ def _batch_search_ops(
if self.cluster_mode:
for key in refresh_keys:
ttl = self._redis.ttl(key)
if ttl > 0: # type: ignore
if ttl > 0:
self._redis.expire(key, ttl_seconds)
else:
pipeline = self._redis.pipeline(transaction=True)
for key in refresh_keys:
# Only refresh TTL if the key exists and has a TTL
ttl = self._redis.ttl(key)
if (
ttl > 0
): # Only refresh if key exists and has TTL # type: ignore
if ttl > 0: # Only refresh if key exists and has TTL
pipeline.expire(key, ttl_seconds)
if pipeline.command_stack:
pipeline.execute()
Expand Down Expand Up @@ -645,16 +643,14 @@ def _batch_search_ops(
if self.cluster_mode:
for key in refresh_keys:
ttl = self._redis.ttl(key)
if ttl > 0: # type: ignore
if ttl > 0:
self._redis.expire(key, ttl_seconds)
else:
pipeline = self._redis.pipeline(transaction=True)
for key in refresh_keys:
# Only refresh TTL if the key exists and has a TTL
ttl = self._redis.ttl(key)
if (
ttl > 0
): # Only refresh if key exists and has TTL # type: ignore
if ttl > 0: # Only refresh if key exists and has TTL
pipeline.expire(key, ttl_seconds)
if pipeline.command_stack:
pipeline.execute()
Expand Down
4 changes: 2 additions & 2 deletions langgraph/store/redis/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(

# Set up store configuration
self.index_config = index
self.ttl_config = ttl # type: ignore
self.ttl_config = ttl

if self.index_config:
self.index_config = self.index_config.copy()
Expand Down Expand Up @@ -744,7 +744,7 @@ async def _batch_search_ops(
store_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_uuid}"
result_map[store_key] = doc
# Fetch individually in cluster mode
store_doc_item = await self._redis.json().get(store_key)
store_doc_item = await self._redis.json().get(store_key) # type: ignore
store_docs.append(store_doc_item)
store_docs_raw = store_docs
else:
Expand Down
Loading