3
3
import json
4
4
import random
5
5
from abc import abstractmethod
6
- from collections .abc import Sequence
7
- from typing import Any , Generic , List , Optional , cast
6
+ from typing import Any , Dict , Generic , List , Optional , Sequence , Tuple , cast
8
7
9
8
from langchain_core .runnables import RunnableConfig
10
9
from langgraph .checkpoint .base import (
@@ -100,12 +99,16 @@ def __init__(
100
99
redis_url : Optional [str ] = None ,
101
100
* ,
102
101
redis_client : Optional [RedisClientType ] = None ,
103
- connection_args : Optional [dict [str , Any ]] = None ,
102
+ connection_args : Optional [Dict [str , Any ]] = None ,
103
+ ttl : Optional [Dict [str , Any ]] = None ,
104
104
) -> None :
105
105
super ().__init__ (serde = JsonPlusRedisSerializer ())
106
106
if redis_url is None and redis_client is None :
107
107
raise ValueError ("Either redis_url or redis_client must be provided" )
108
108
109
+ # Store TTL configuration
110
+ self .ttl_config = ttl
111
+
109
112
self .configure_client (
110
113
redis_url = redis_url ,
111
114
redis_client = redis_client ,
@@ -128,7 +131,7 @@ def configure_client(
128
131
self ,
129
132
redis_url : Optional [str ] = None ,
130
133
redis_client : Optional [RedisClientType ] = None ,
131
- connection_args : Optional [dict [str , Any ]] = None ,
134
+ connection_args : Optional [Dict [str , Any ]] = None ,
132
135
) -> None :
133
136
"""Configure the Redis client."""
134
137
pass
@@ -180,11 +183,46 @@ def setup(self) -> None:
180
183
self .checkpoint_blobs_index .create (overwrite = False )
181
184
self .checkpoint_writes_index .create (overwrite = False )
182
185
186
+ def _apply_ttl_to_keys (
187
+ self ,
188
+ main_key : str ,
189
+ related_keys : Optional [List [str ]] = None ,
190
+ ttl_minutes : Optional [float ] = None ,
191
+ ) -> Any :
192
+ """Apply Redis native TTL to keys.
193
+
194
+ Args:
195
+ main_key: The primary Redis key
196
+ related_keys: Additional Redis keys that should expire at the same time
197
+ ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided
198
+
199
+ Returns:
200
+ Result of the Redis operation
201
+ """
202
+ if ttl_minutes is None :
203
+ # Check if there's a default TTL in config
204
+ if self .ttl_config and "default_ttl" in self .ttl_config :
205
+ ttl_minutes = self .ttl_config .get ("default_ttl" )
206
+
207
+ if ttl_minutes is not None :
208
+ ttl_seconds = int (ttl_minutes * 60 )
209
+ pipeline = self ._redis .pipeline ()
210
+
211
+ # Set TTL for main key
212
+ pipeline .expire (main_key , ttl_seconds )
213
+
214
+ # Set TTL for related keys
215
+ if related_keys :
216
+ for key in related_keys :
217
+ pipeline .expire (key , ttl_seconds )
218
+
219
+ return pipeline .execute ()
220
+
183
221
def _load_checkpoint (
184
222
self ,
185
- checkpoint : dict [str , Any ],
186
- channel_values : dict [str , Any ],
187
- pending_sends : list [Any ],
223
+ checkpoint : Dict [str , Any ],
224
+ channel_values : Dict [str , Any ],
225
+ pending_sends : List [Any ],
188
226
) -> Checkpoint :
189
227
if not checkpoint :
190
228
return {}
@@ -218,7 +256,7 @@ def _load_blobs(self, blob_values: dict[str, Any]) -> dict[str, Any]:
218
256
if v ["type" ] != "empty"
219
257
}
220
258
221
- def _get_type_and_blob (self , value : Any ) -> tuple [str , Optional [bytes ]]:
259
+ def _get_type_and_blob (self , value : Any ) -> Tuple [str , Optional [bytes ]]:
222
260
"""Helper to get type and blob from a value."""
223
261
t , b = self .serde .dumps_typed (value )
224
262
return t , b
@@ -227,9 +265,9 @@ def _dump_blobs(
227
265
self ,
228
266
thread_id : str ,
229
267
checkpoint_ns : str ,
230
- values : dict [str , Any ],
268
+ values : Dict [str , Any ],
231
269
versions : ChannelVersions ,
232
- ) -> list [ tuple [str , dict [str , Any ]]]:
270
+ ) -> List [ Tuple [str , Dict [str , Any ]]]:
233
271
"""Convert blob data for Redis storage."""
234
272
if not versions :
235
273
return []
@@ -337,7 +375,7 @@ def _decode_blob(self, blob: str) -> bytes:
337
375
# Handle both malformed base64 data and incorrect input types
338
376
return blob .encode () if isinstance (blob , str ) else blob
339
377
340
- def _load_writes_from_redis (self , write_key : str ) -> list [ tuple [str , str , Any ]]:
378
+ def _load_writes_from_redis (self , write_key : str ) -> List [ Tuple [str , str , Any ]]:
341
379
"""Load writes from Redis JSON storage by key."""
342
380
if not write_key :
343
381
return []
0 commit comments