Commit 2bb565d

rename redis file
1 parent 5c488b5 commit 2bb565d

1 file changed: +297 −0 lines changed

@@ -0,0 +1,297 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import asyncio
import bisect
import logging
import time
import uuid
from typing import Any, Optional

import orjson
import redis.asyncio as redis
from pydantic import ValidationError
from redis.crc import key_slot
from typing_extensions import override

from google.adk.events.event import Event
from google.adk.sessions.base_session_service import (
    BaseSessionService,
    GetSessionConfig,
    ListSessionsResponse,
)
from google.adk.sessions.session import Session
from google.adk.sessions.state import State

from .utils import _json_serializer

logger = logging.getLogger("google_adk." + __name__)

DEFAULT_EXPIRATION = 60 * 60  # 1 hour


def _session_serializer(obj: Session) -> bytes:
    """Serialize ADK Session to JSON bytes."""
    return orjson.dumps(obj.model_dump(), default=_json_serializer)


class RedisKeys:
    """Helper to generate Redis keys consistently."""
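
    # Key layout (symbolic; prefix values come from google.adk's State):
    #   session:{session_id}               -> serialized Session JSON
    #   {APP_PREFIX}:{app_name}:{user_id}  -> set of the user's session ids
    #   {APP_PREFIX}{app_name}             -> hash of app-scoped state
    #   {USER_PREFIX}{app_name}:{user_id}  -> hash of user-scoped state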

    @staticmethod
    def session(session_id: str) -> str:
        return f"session:{session_id}"

    @staticmethod
    def user_sessions(app_name: str, user_id: str) -> str:
        return f"{State.APP_PREFIX}:{app_name}:{user_id}"

    @staticmethod
    def app_state(app_name: str) -> str:
        return f"{State.APP_PREFIX}{app_name}"

    @staticmethod
    def user_state(app_name: str, user_id: str) -> str:
        return f"{State.USER_PREFIX}{app_name}:{user_id}"


class RedisSessionService(BaseSessionService):
    """A Redis-backed implementation of the session service."""
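
    # A minimal usage sketch (hypothetical app/user names; assumes a Redis
    # server reachable at the defaults):
    #
    #     service = RedisSessionService(host="localhost", port=6379)
    #     session = await service.create_session(app_name="demo", user_id="u1")
    #     session = await service.get_session(
    #         app_name="demo", user_id="u1", session_id=session.id
    #     )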

    def __init__(
        self,
        host="localhost",
        port=6379,
        db=0,
        uri=None,
        cluster_uri=None,
        expire=DEFAULT_EXPIRATION,
        **kwargs,
    ):
        self.expire = expire
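
        # Connection selection: an explicit cluster URI takes precedence,
        # then a single-node URI, then discrete host/port/db parameters.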
        if cluster_uri:
            self.cache = redis.RedisCluster.from_url(cluster_uri, **kwargs)
        elif uri:
            self.cache = redis.Redis.from_url(uri, **kwargs)
        else:
            self.cache = redis.Redis(host=host, port=port, db=db, **kwargs)

    async def health_check(self) -> bool:
        try:
            await self.cache.ping()
            return True
        except redis.RedisError:
            return False

    @override
    async def create_session(
        self,
        *,
        app_name: str,
        user_id: str,
        state: Optional[dict[str, Any]] = None,
        session_id: Optional[str] = None,
    ) -> Session:
        session_id = (
            session_id.strip()
            if session_id and session_id.strip()
            else str(uuid.uuid4())
        )
        session = Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=state or {},
            last_update_time=time.time(),
        )

        user_sessions_key = RedisKeys.user_sessions(app_name, user_id)
        session_key = RedisKeys.session(session_id)
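
        # One non-transactional pipeline round trip: add the id to the user's
        # session set, refresh that set's TTL, and store the serialized
        # session under the same TTL.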
        async with self.cache.pipeline(transaction=False) as pipe:
            pipe.sadd(user_sessions_key, session_id)
            pipe.expire(user_sessions_key, self.expire)
            pipe.set(
                session_key,
                _session_serializer(session),
                ex=self.expire,
            )
            await pipe.execute()

        return await self._merge_state(app_name, user_id, session)

    @override
    async def get_session(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        config: Optional[GetSessionConfig] = None,
    ) -> Optional[Session]:
        session_key = RedisKeys.session(session_id)
        raw_session = await self.cache.get(session_key)
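
        # The session value may have expired while its id lingers in the
        # user set; treat that as not found and lazily prune the stale id.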
        if not raw_session:
            user_sessions_key = RedisKeys.user_sessions(app_name, user_id)
            await self.cache.srem(user_sessions_key, session_id)
            return None

        try:
            session_dict = orjson.loads(raw_session)
            session = Session.model_validate(session_dict)
        except (orjson.JSONDecodeError, ValidationError) as e:
            logger.error("Error decoding session %s: %s", session_id, e)
            return None
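
        # Optional trimming: keep only the N most recent events and/or drop
        # events before the cutoff timestamp (events are kept in ascending
        # timestamp order, so bisect_left finds the cut point).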
        if config:
            if config.num_recent_events:
                session.events = session.events[-config.num_recent_events :]
            if config.after_timestamp:
                timestamps = [e.timestamp for e in session.events]
                start_index = bisect.bisect_left(timestamps, config.after_timestamp)
                session.events = session.events[start_index:]

        return await self._merge_state(app_name, user_id, session)

    @override
    async def list_sessions(
        self, *, app_name: str, user_id: str
    ) -> ListSessionsResponse:
        sessions = await self._load_sessions(app_name, user_id)
        sessions_without_events = []
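
        # Return lightweight listings: strip events and per-session state so
        # callers fetch full details with get_session.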
        for session_data in sessions.values():
            session = Session.model_validate(session_data)
            session.events = []
            session.state = {}
            sessions_without_events.append(session)

        return ListSessionsResponse(sessions=sessions_without_events)

    @override
    async def delete_session(
        self, *, app_name: str, user_id: str, session_id: str
    ) -> None:
        user_sessions_key = RedisKeys.user_sessions(app_name, user_id)
        session_key = RedisKeys.session(session_id)

        async with self.cache.pipeline(transaction=False) as pipe:
            pipe.srem(user_sessions_key, session_id)
            pipe.delete(session_key)
            await pipe.execute()

    @override
    async def append_event(self, session: Session, event: Event) -> Event:
        await super().append_event(session=session, event=event)
        session.last_update_time = event.timestamp

        async with self.cache.pipeline(transaction=False) as pipe:
            user_sessions_key = RedisKeys.user_sessions(
                session.app_name, session.user_id
            )
            pipe.expire(user_sessions_key, self.expire)
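
            # Deltas keyed with APP_PREFIX/USER_PREFIX are shared across
            # sessions, so write them (prefix stripped) into the app and user
            # hashes; all other keys stay inside the serialized session state.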
            if event.actions and event.actions.state_delta:
                for key, value in event.actions.state_delta.items():
                    if key.startswith(State.APP_PREFIX):
                        pipe.hset(
                            RedisKeys.app_state(session.app_name),
                            key.removeprefix(State.APP_PREFIX),
                            orjson.dumps(value),
                        )
                    if key.startswith(State.USER_PREFIX):
                        pipe.hset(
                            RedisKeys.user_state(session.app_name, session.user_id),
                            key.removeprefix(State.USER_PREFIX),
                            orjson.dumps(value),
                        )

            pipe.set(
                RedisKeys.session(session.id),
                _session_serializer(session),
                ex=self.expire,
            )
            await pipe.execute()

        return event

    async def _merge_state(
        self, app_name: str, user_id: str, session: Session
    ) -> Session:
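        # Overlay shared app- and user-scoped state onto the session's own
        # state, restoring the prefixes that append_event stripped on write.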
        app_state = await self.cache.hgetall(RedisKeys.app_state(app_name))
        for k, v in app_state.items():
            session.state[State.APP_PREFIX + k.decode()] = orjson.loads(v)

        user_state = await self.cache.hgetall(RedisKeys.user_state(app_name, user_id))
        for k, v in user_state.items():
            session.state[State.USER_PREFIX + k.decode()] = orjson.loads(v)

        return session

    async def _load_sessions(self, app_name: str, user_id: str) -> dict[str, dict]:
        key = RedisKeys.user_sessions(app_name, user_id)
        try:
            session_ids_bytes = await self.cache.smembers(key)
            if not session_ids_bytes:
                return {}

            session_ids = [s.decode() for s in session_ids_bytes]
            session_keys = [RedisKeys.session(sid) for sid in session_ids]

            # Redis Cluster pipelines can only batch keys that hash to the
            # same slot, so group the keys by slot and fetch each group in
            # its own pipeline.
            slot_groups: dict[int, list[str]] = {}
            for k in session_keys:
                slot = key_slot(k.encode())
                slot_groups.setdefault(slot, []).append(k)

            async def fetch_group(keys: list[str]):
                async with self.cache.pipeline(transaction=False) as pipe:
                    for k in keys:
                        pipe.get(k)
                    return await pipe.execute()
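
            # Fetch all slot groups concurrently; gather preserves argument
            # order, so results align with slot_groups.values().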
            results_per_group = await asyncio.gather(
                *(fetch_group(keys) for keys in slot_groups.values())
            )

            raw_sessions = []
            for group_keys, group_results in zip(
                slot_groups.values(), results_per_group
            ):
                raw_sessions.extend(zip(group_keys, group_results))

            sessions = {}
            sessions_to_cleanup = []
            for key_name, raw_session in raw_sessions:
                session_id = key_name.split(":", 1)[1]
                if raw_session:
                    try:
                        sessions[session_id] = orjson.loads(raw_session)
                    except orjson.JSONDecodeError as e:
                        logger.error("Error decoding session %s: %s", session_id, e)
                else:
                    logger.warning(
                        "Session ID %s found in user set but session data is missing. Cleaning up.",
                        session_id,
                    )
                    sessions_to_cleanup.append(session_id)

            if sessions_to_cleanup:
                await self.cache.srem(key, *sessions_to_cleanup)

            return sessions
        except redis.RedisError as e:
            logger.error("Error loading sessions for %s: %s", user_id, e)
            return {}
