Skip to content

Commit fbe9ecc

Browse files
wuliang229 and copybara-github
authored and committed
fix: race condition in table creation for DatabaseSessionService
Using one lock and checking for table creation instead of the schema version.
Closes issue #4445
Co-authored-by: Liang Wu <wuliang@google.com>
PiperOrigin-RevId: 869808097
1 parent 186371f commit fbe9ecc

File tree

2 files changed

+144
-51
lines changed

2 files changed

+144
-51
lines changed

src/google/adk/sessions/database_session_service.py

Lines changed: 41 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -187,9 +187,6 @@ def __init__(self, db_url: str, **kwargs: Any):
187187
# The current database schema version in use, "None" if not yet checked
188188
self._db_schema_version: Optional[str] = None
189189

190-
# Lock to ensure thread-safe schema version check
191-
self._db_schema_lock = asyncio.Lock()
192-
193190
# Per-session locks used to serialize append_event calls in this process.
194191
self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {}
195192
self._session_lock_ref_count: dict[_SessionLockKey, int] = {}
@@ -261,62 +258,55 @@ async def _prepare_tables(self):
261258
DB schema version to use and creates the tables (including setting the
262259
schema version metadata) if needed.
263260
"""
264-
# Check the database schema version and set the _db_schema_version if
265-
# needed
266-
if self._db_schema_version is not None:
267-
return
268-
269-
async with self._db_schema_lock:
270-
# Double-check after acquiring the lock
271-
if self._db_schema_version is not None:
272-
return
273-
try:
274-
async with self.db_engine.connect() as conn:
275-
self._db_schema_version = await conn.run_sync(
276-
_schema_check_utils.get_db_schema_version_from_connection
277-
)
278-
except Exception as e:
279-
logger.error("Failed to inspect database tables: %s", e)
280-
raise
281-
282-
# Check if tables are created and create them if not
261+
# Early return if tables are already created
283262
if self._tables_created:
284263
return
285264

286265
async with self._table_creation_lock:
287266
# Double-check after acquiring the lock
288-
if not self._tables_created:
289-
async with self.db_engine.begin() as conn:
290-
if (
291-
self._db_schema_version
292-
== _schema_check_utils.LATEST_SCHEMA_VERSION
293-
):
294-
# Uncomment to recreate DB every time
295-
# await conn.run_sync(BaseV1.metadata.drop_all)
296-
logger.debug("Using V1 schema tables...")
297-
await conn.run_sync(BaseV1.metadata.create_all)
298-
else:
299-
# await conn.run_sync(BaseV0.metadata.drop_all)
300-
logger.debug("Using V0 schema tables...")
301-
await conn.run_sync(BaseV0.metadata.create_all)
302-
self._tables_created = True
267+
if self._tables_created:
268+
return
269+
270+
# Check the database schema version and set the _db_schema_version
271+
if self._db_schema_version is None:
272+
try:
273+
async with self.db_engine.connect() as conn:
274+
self._db_schema_version = await conn.run_sync(
275+
_schema_check_utils.get_db_schema_version_from_connection
276+
)
277+
except Exception as e:
278+
logger.error("Failed to inspect database tables: %s", e)
279+
raise
303280

281+
async with self.db_engine.begin() as conn:
304282
if self._db_schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION:
305-
async with self._rollback_on_exception_session() as sql_session:
306-
# Check if schema version is set, if not, set it to the latest
307-
# version
308-
stmt = select(StorageMetadata).where(
309-
StorageMetadata.key == _schema_check_utils.SCHEMA_VERSION_KEY
283+
# Uncomment to recreate DB every time
284+
# await conn.run_sync(BaseV1.metadata.drop_all)
285+
logger.debug("Using V1 schema tables...")
286+
await conn.run_sync(BaseV1.metadata.create_all)
287+
else:
288+
# await conn.run_sync(BaseV0.metadata.drop_all)
289+
logger.debug("Using V0 schema tables...")
290+
await conn.run_sync(BaseV0.metadata.create_all)
291+
292+
if self._db_schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION:
293+
async with self._rollback_on_exception_session() as sql_session:
294+
# Check if schema version is set, if not, set it to the latest
295+
# version
296+
stmt = select(StorageMetadata).where(
297+
StorageMetadata.key == _schema_check_utils.SCHEMA_VERSION_KEY
298+
)
299+
result = await sql_session.execute(stmt)
300+
metadata = result.scalars().first()
301+
if not metadata:
302+
metadata = StorageMetadata(
303+
key=_schema_check_utils.SCHEMA_VERSION_KEY,
304+
value=_schema_check_utils.LATEST_SCHEMA_VERSION,
310305
)
311-
result = await sql_session.execute(stmt)
312-
metadata = result.scalars().first()
313-
if not metadata:
314-
metadata = StorageMetadata(
315-
key=_schema_check_utils.SCHEMA_VERSION_KEY,
316-
value=_schema_check_utils.LATEST_SCHEMA_VERSION,
317-
)
318-
sql_session.add(metadata)
319-
await sql_session.commit()
306+
sql_session.add(metadata)
307+
await sql_session.commit()
308+
309+
self._tables_created = True
320310

321311
@override
322312
async def create_session(

tests/unittests/sessions/test_session_service.py

Lines changed: 103 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1050,3 +1050,106 @@ def _spy_factory():
10501050
assert session.id == 'recovered'
10511051
finally:
10521052
await service.close()
1053+
1054+
1055+
@pytest.mark.asyncio
1056+
async def test_concurrent_prepare_tables_no_race_condition():
1057+
"""Verifies that concurrent calls to _prepare_tables wait for table creation.
1058+
Reproduces the race condition from
1059+
https://github.com/google/adk-python/issues/4445: when concurrent requests
1060+
arrive at startup, _prepare_tables must not return before tables exist.
1061+
Previously, the early-return guard checked _db_schema_version (set during
1062+
schema detection) instead of _tables_created, so a second request could
1063+
slip through after schema detection but before table creation finished.
1064+
"""
1065+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
1066+
try:
1067+
# Tables haven't been created yet.
1068+
assert not service._tables_created
1069+
assert service._db_schema_version is None
1070+
1071+
# Launch several concurrent create_session calls, each with a unique
1072+
# app_name to avoid IntegrityError on the shared app_states row.
1073+
# Each will call _prepare_tables internally. If the race condition
1074+
# exists, some of these will fail because the "sessions" table doesn't
1075+
# exist yet.
1076+
num_concurrent = 5
1077+
results = await asyncio.gather(
1078+
*[
1079+
service.create_session(
1080+
app_name=f'app_{i}', user_id='user', session_id=f'sess_{i}'
1081+
)
1082+
for i in range(num_concurrent)
1083+
],
1084+
return_exceptions=True,
1085+
)
1086+
1087+
# Every call must succeed – no exceptions allowed.
1088+
for i, result in enumerate(results):
1089+
assert not isinstance(result, BaseException), (
1090+
f'Concurrent create_session #{i} raised {result!r}; tables were'
1091+
' likely not ready due to the _prepare_tables race condition.'
1092+
)
1093+
1094+
# All sessions should be retrievable.
1095+
for i in range(num_concurrent):
1096+
session = await service.get_session(
1097+
app_name=f'app_{i}', user_id='user', session_id=f'sess_{i}'
1098+
)
1099+
assert session is not None, f'Session sess_{i} not found after creation.'
1100+
1101+
assert service._tables_created
1102+
finally:
1103+
await service.close()
1104+
1105+
1106+
@pytest.mark.asyncio
1107+
async def test_prepare_tables_serializes_schema_detection_and_creation():
1108+
"""Verifies schema detection and table creation happen atomically under one
1109+
lock, so concurrent callers cannot observe a partially-initialized state.
1110+
After _prepare_tables completes, both _db_schema_version and _tables_created
1111+
must be set.
1112+
"""
1113+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
1114+
try:
1115+
assert not service._tables_created
1116+
assert service._db_schema_version is None
1117+
1118+
await service._prepare_tables()
1119+
1120+
# Both must be set after a single _prepare_tables call.
1121+
assert service._tables_created
1122+
assert service._db_schema_version is not None
1123+
1124+
# Verify tables actually exist by performing a real operation.
1125+
session = await service.create_session(
1126+
app_name='app', user_id='user', session_id='s1'
1127+
)
1128+
assert session is not None
1129+
assert session.id == 's1'
1130+
finally:
1131+
await service.close()
1132+
1133+
1134+
@pytest.mark.asyncio
1135+
async def test_prepare_tables_idempotent_after_creation():
1136+
"""Calling _prepare_tables multiple times is safe and idempotent.
1137+
After tables are created, subsequent calls should return immediately via
1138+
the fast path without errors.
1139+
"""
1140+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
1141+
try:
1142+
await service._prepare_tables()
1143+
assert service._tables_created
1144+
1145+
# Call again — should be a no-op via the fast path.
1146+
await service._prepare_tables()
1147+
assert service._tables_created
1148+
1149+
# Service should still work.
1150+
session = await service.create_session(
1151+
app_name='app', user_id='user', session_id='s1'
1152+
)
1153+
assert session.id == 's1'
1154+
finally:
1155+
await service.close()

0 commit comments

Comments
 (0)