Skip to content

Commit 43f2b6b

Browse files
committed
refactor: align types
1 parent 323a485 commit 43f2b6b

File tree

3 files changed

+69
-25
lines changed

3 files changed

+69
-25
lines changed

langgraph/checkpoint/redis/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from __future__ import annotations
22

33
import json
4-
from collections.abc import Iterator
54
from contextlib import contextmanager
6-
from typing import Any, List, Optional, Tuple, cast
5+
from typing import Any, Dict, Iterator, List, Optional, Tuple, cast
76

87
from langchain_core.runnables import RunnableConfig
98
from langgraph.checkpoint.base import (
@@ -42,19 +41,21 @@ def __init__(
4241
redis_url: Optional[str] = None,
4342
*,
4443
redis_client: Optional[Redis] = None,
45-
connection_args: Optional[dict[str, Any]] = None,
44+
connection_args: Optional[Dict[str, Any]] = None,
45+
ttl: Optional[Dict[str, Any]] = None,
4646
) -> None:
4747
super().__init__(
4848
redis_url=redis_url,
4949
redis_client=redis_client,
5050
connection_args=connection_args,
51+
ttl=ttl,
5152
)
5253

5354
def configure_client(
5455
self,
5556
redis_url: Optional[str] = None,
5657
redis_client: Optional[Redis] = None,
57-
connection_args: Optional[dict[str, Any]] = None,
58+
connection_args: Optional[Dict[str, Any]] = None,
5859
) -> None:
5960
"""Configure the Redis client."""
6061
self._owns_its_client = redis_client is None
@@ -395,7 +396,8 @@ def from_conn_string(
395396
redis_url: Optional[str] = None,
396397
*,
397398
redis_client: Optional[Redis] = None,
398-
connection_args: Optional[dict[str, Any]] = None,
399+
connection_args: Optional[Dict[str, Any]] = None,
400+
ttl: Optional[Dict[str, Any]] = None,
399401
) -> Iterator[RedisSaver]:
400402
"""Create a new RedisSaver instance."""
401403
saver: Optional[RedisSaver] = None
@@ -404,6 +406,7 @@ def from_conn_string(
404406
redis_url=redis_url,
405407
redis_client=redis_client,
406408
connection_args=connection_args,
409+
ttl=ttl,
407410
)
408411

409412
yield saver
@@ -414,7 +417,7 @@ def from_conn_string(
414417

415418
def get_channel_values(
416419
self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = ""
417-
) -> dict[str, Any]:
420+
) -> Dict[str, Any]:
418421
"""Retrieve channel_values dictionary with properly constructed message objects."""
419422
storage_safe_thread_id = to_storage_safe_id(thread_id)
420423
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)

langgraph/checkpoint/redis/aio.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
import asyncio
66
import json
77
import os
8-
from collections.abc import AsyncIterator
98
from contextlib import asynccontextmanager
109
from functools import partial
1110
from types import TracebackType
12-
from typing import Any, List, Optional, Sequence, Tuple, Type, cast
11+
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Type, cast
1312

1413
from langchain_core.runnables import RunnableConfig
1514
from langgraph.checkpoint.base import (
@@ -42,7 +41,7 @@
4241
async def _write_obj_tx(
4342
pipe: Pipeline,
4443
key: str,
45-
write_obj: dict[str, Any],
44+
write_obj: Dict[str, Any],
4645
upsert_case: bool,
4746
) -> None:
4847
exists: int = await pipe.exists(key)
@@ -73,20 +72,22 @@ def __init__(
7372
redis_url: Optional[str] = None,
7473
*,
7574
redis_client: Optional[AsyncRedis] = None,
76-
connection_args: Optional[dict[str, Any]] = None,
75+
connection_args: Optional[Dict[str, Any]] = None,
76+
ttl: Optional[Dict[str, Any]] = None,
7777
) -> None:
7878
super().__init__(
7979
redis_url=redis_url,
8080
redis_client=redis_client,
8181
connection_args=connection_args,
82+
ttl=ttl,
8283
)
8384
self.loop = asyncio.get_running_loop()
8485

8586
def configure_client(
8687
self,
8788
redis_url: Optional[str] = None,
8889
redis_client: Optional[AsyncRedis] = None,
89-
connection_args: Optional[dict[str, Any]] = None,
90+
connection_args: Optional[Dict[str, Any]] = None,
9091
) -> None:
9192
"""Configure the Redis client."""
9293
self._owns_its_client = redis_client is None
@@ -706,18 +707,20 @@ async def from_conn_string(
706707
redis_url: Optional[str] = None,
707708
*,
708709
redis_client: Optional[AsyncRedis] = None,
709-
connection_args: Optional[dict[str, Any]] = None,
710+
connection_args: Optional[Dict[str, Any]] = None,
711+
ttl: Optional[Dict[str, Any]] = None,
710712
) -> AsyncIterator[AsyncRedisSaver]:
711713
async with cls(
712714
redis_url=redis_url,
713715
redis_client=redis_client,
714716
connection_args=connection_args,
717+
ttl=ttl,
715718
) as saver:
716719
yield saver
717720

718721
async def aget_channel_values(
719722
self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = ""
720-
) -> dict[str, Any]:
723+
) -> Dict[str, Any]:
721724
"""Retrieve channel_values dictionary with properly constructed message objects."""
722725
storage_safe_thread_id = to_storage_safe_id(thread_id)
723726
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
@@ -767,7 +770,7 @@ async def aget_channel_values(
767770

768771
async def _aload_pending_sends(
769772
self, thread_id: str, checkpoint_ns: str = "", parent_checkpoint_id: str = ""
770-
) -> list[tuple[str, bytes]]:
773+
) -> List[Tuple[str, bytes]]:
771774
"""Load pending sends for a parent checkpoint.
772775
773776
Args:

langgraph/checkpoint/redis/base.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import json
44
import random
55
from abc import abstractmethod
6-
from collections.abc import Sequence
7-
from typing import Any, Generic, List, Optional, cast
6+
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, cast
87

98
from langchain_core.runnables import RunnableConfig
109
from langgraph.checkpoint.base import (
@@ -100,12 +99,16 @@ def __init__(
10099
redis_url: Optional[str] = None,
101100
*,
102101
redis_client: Optional[RedisClientType] = None,
103-
connection_args: Optional[dict[str, Any]] = None,
102+
connection_args: Optional[Dict[str, Any]] = None,
103+
ttl: Optional[Dict[str, Any]] = None,
104104
) -> None:
105105
super().__init__(serde=JsonPlusRedisSerializer())
106106
if redis_url is None and redis_client is None:
107107
raise ValueError("Either redis_url or redis_client must be provided")
108108

109+
# Store TTL configuration
110+
self.ttl_config = ttl
111+
109112
self.configure_client(
110113
redis_url=redis_url,
111114
redis_client=redis_client,
@@ -128,7 +131,7 @@ def configure_client(
128131
self,
129132
redis_url: Optional[str] = None,
130133
redis_client: Optional[RedisClientType] = None,
131-
connection_args: Optional[dict[str, Any]] = None,
134+
connection_args: Optional[Dict[str, Any]] = None,
132135
) -> None:
133136
"""Configure the Redis client."""
134137
pass
@@ -180,11 +183,46 @@ def setup(self) -> None:
180183
self.checkpoint_blobs_index.create(overwrite=False)
181184
self.checkpoint_writes_index.create(overwrite=False)
182185

186+
def _apply_ttl_to_keys(
187+
self,
188+
main_key: str,
189+
related_keys: Optional[List[str]] = None,
190+
ttl_minutes: Optional[float] = None,
191+
) -> Any:
192+
"""Apply Redis native TTL to keys.
193+
194+
Args:
195+
main_key: The primary Redis key
196+
related_keys: Additional Redis keys that should expire at the same time
197+
ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided
198+
199+
Returns:
200+
Result of the Redis operation
201+
"""
202+
if ttl_minutes is None:
203+
# Check if there's a default TTL in config
204+
if self.ttl_config and "default_ttl" in self.ttl_config:
205+
ttl_minutes = self.ttl_config.get("default_ttl")
206+
207+
if ttl_minutes is not None:
208+
ttl_seconds = int(ttl_minutes * 60)
209+
pipeline = self._redis.pipeline()
210+
211+
# Set TTL for main key
212+
pipeline.expire(main_key, ttl_seconds)
213+
214+
# Set TTL for related keys
215+
if related_keys:
216+
for key in related_keys:
217+
pipeline.expire(key, ttl_seconds)
218+
219+
return pipeline.execute()
220+
183221
def _load_checkpoint(
184222
self,
185-
checkpoint: dict[str, Any],
186-
channel_values: dict[str, Any],
187-
pending_sends: list[Any],
223+
checkpoint: Dict[str, Any],
224+
channel_values: Dict[str, Any],
225+
pending_sends: List[Any],
188226
) -> Checkpoint:
189227
if not checkpoint:
190228
return {}
@@ -218,7 +256,7 @@ def _load_blobs(self, blob_values: dict[str, Any]) -> dict[str, Any]:
218256
if v["type"] != "empty"
219257
}
220258

221-
def _get_type_and_blob(self, value: Any) -> tuple[str, Optional[bytes]]:
259+
def _get_type_and_blob(self, value: Any) -> Tuple[str, Optional[bytes]]:
222260
"""Helper to get type and blob from a value."""
223261
t, b = self.serde.dumps_typed(value)
224262
return t, b
@@ -227,9 +265,9 @@ def _dump_blobs(
227265
self,
228266
thread_id: str,
229267
checkpoint_ns: str,
230-
values: dict[str, Any],
268+
values: Dict[str, Any],
231269
versions: ChannelVersions,
232-
) -> list[tuple[str, dict[str, Any]]]:
270+
) -> List[Tuple[str, Dict[str, Any]]]:
233271
"""Convert blob data for Redis storage."""
234272
if not versions:
235273
return []
@@ -337,7 +375,7 @@ def _decode_blob(self, blob: str) -> bytes:
337375
# Handle both malformed base64 data and incorrect input types
338376
return blob.encode() if isinstance(blob, str) else blob
339377

340-
def _load_writes_from_redis(self, write_key: str) -> list[tuple[str, str, Any]]:
378+
def _load_writes_from_redis(self, write_key: str) -> List[Tuple[str, str, Any]]:
341379
"""Load writes from Redis JSON storage by key."""
342380
if not write_key:
343381
return []

0 commit comments

Comments
 (0)