Skip to content

Commit 2a60325

Browse files
committed
startup and shutdown replaced with lifespan
1 parent 444ce98 commit 2a60325

File tree

1 file changed

+49
-20
lines changed

1 file changed

+49
-20
lines changed

src/app/core/setup.py

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections.abc import AsyncGenerator, Callable
2+
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
13
from typing import Any
24

35
import anyio
@@ -68,6 +70,50 @@ async def set_threadpool_tokens(number_of_tokens: int = 100) -> None:
6870
limiter.total_tokens = number_of_tokens
6971

7072

73+
def lifespan_factory(
74+
settings: (
75+
DatabaseSettings
76+
| RedisCacheSettings
77+
| AppSettings
78+
| ClientSideCacheSettings
79+
| RedisQueueSettings
80+
| RedisRateLimiterSettings
81+
| EnvironmentSettings
82+
),
83+
create_tables_on_start: bool = True,
84+
) -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]]:
85+
"""Factory to create a lifespan async context manager for a FastAPI app."""
86+
87+
@asynccontextmanager
88+
async def lifespan(app: FastAPI) -> AsyncGenerator:
89+
await set_threadpool_tokens()
90+
91+
if isinstance(settings, DatabaseSettings) and create_tables_on_start:
92+
await create_tables()
93+
94+
if isinstance(settings, RedisCacheSettings):
95+
await create_redis_cache_pool()
96+
97+
if isinstance(settings, RedisQueueSettings):
98+
await create_redis_queue_pool()
99+
100+
if isinstance(settings, RedisRateLimiterSettings):
101+
await create_redis_rate_limit_pool()
102+
103+
yield
104+
105+
if isinstance(settings, RedisCacheSettings):
106+
await close_redis_cache_pool()
107+
108+
if isinstance(settings, RedisQueueSettings):
109+
await close_redis_queue_pool()
110+
111+
if isinstance(settings, RedisRateLimiterSettings):
112+
await close_redis_rate_limit_pool()
113+
114+
return lifespan
115+
116+
71117
# -------------- application --------------
72118
def create_application(
73119
router: APIRouter,
@@ -136,30 +182,13 @@ def create_application(
136182
if isinstance(settings, EnvironmentSettings):
137183
kwargs.update({"docs_url": None, "redoc_url": None, "openapi_url": None})
138184

139-
application = FastAPI(**kwargs)
140-
141-
# --- application created ---
142-
application.include_router(router)
143-
application.add_event_handler("startup", set_threadpool_tokens)
185+
lifespan = lifespan_factory(settings, create_tables_on_start=create_tables_on_start)
144186

145-
if isinstance(settings, DatabaseSettings) and create_tables_on_start:
146-
application.add_event_handler("startup", create_tables)
147-
148-
if isinstance(settings, RedisCacheSettings):
149-
application.add_event_handler("startup", create_redis_cache_pool)
150-
application.add_event_handler("shutdown", close_redis_cache_pool)
187+
application = FastAPI(lifespan=lifespan, **kwargs)
151188

152189
if isinstance(settings, ClientSideCacheSettings):
153190
application.add_middleware(ClientCacheMiddleware, max_age=settings.CLIENT_CACHE_MAX_AGE)
154191

155-
if isinstance(settings, RedisQueueSettings):
156-
application.add_event_handler("startup", create_redis_queue_pool)
157-
application.add_event_handler("shutdown", close_redis_queue_pool)
158-
159-
if isinstance(settings, RedisRateLimiterSettings):
160-
application.add_event_handler("startup", create_redis_rate_limit_pool)
161-
application.add_event_handler("shutdown", close_redis_rate_limit_pool)
162-
163192
if isinstance(settings, EnvironmentSettings):
164193
if settings.ENVIRONMENT != EnvironmentOption.PRODUCTION:
165194
docs_router = APIRouter()
@@ -181,4 +210,4 @@ async def openapi() -> dict[str, Any]:
181210

182211
application.include_router(docs_router)
183212

184-
return application
213+
return application

0 commit comments

Comments
 (0)