1
+ from collections .abc import AsyncGenerator , Callable
2
+ from contextlib import _AsyncGeneratorContextManager , asynccontextmanager
1
3
from typing import Any
2
4
3
5
import anyio
@@ -68,6 +70,50 @@ async def set_threadpool_tokens(number_of_tokens: int = 100) -> None:
68
70
limiter .total_tokens = number_of_tokens
69
71
70
72
73
+ def lifespan_factory (
74
+ settings : (
75
+ DatabaseSettings
76
+ | RedisCacheSettings
77
+ | AppSettings
78
+ | ClientSideCacheSettings
79
+ | RedisQueueSettings
80
+ | RedisRateLimiterSettings
81
+ | EnvironmentSettings
82
+ ),
83
+ create_tables_on_start : bool = True ,
84
+ ) -> Callable [[FastAPI ], _AsyncGeneratorContextManager [Any ]]:
85
+ """Factory to create a lifespan async context manager for a FastAPI app."""
86
+
87
+ @asynccontextmanager
88
+ async def lifespan (app : FastAPI ) -> AsyncGenerator :
89
+ await set_threadpool_tokens ()
90
+
91
+ if isinstance (settings , DatabaseSettings ) and create_tables_on_start :
92
+ await create_tables ()
93
+
94
+ if isinstance (settings , RedisCacheSettings ):
95
+ await create_redis_cache_pool ()
96
+
97
+ if isinstance (settings , RedisQueueSettings ):
98
+ await create_redis_queue_pool ()
99
+
100
+ if isinstance (settings , RedisRateLimiterSettings ):
101
+ await create_redis_rate_limit_pool ()
102
+
103
+ yield
104
+
105
+ if isinstance (settings , RedisCacheSettings ):
106
+ await close_redis_cache_pool ()
107
+
108
+ if isinstance (settings , RedisQueueSettings ):
109
+ await close_redis_queue_pool ()
110
+
111
+ if isinstance (settings , RedisRateLimiterSettings ):
112
+ await close_redis_rate_limit_pool ()
113
+
114
+ return lifespan
115
+
116
+
71
117
# -------------- application --------------
72
118
def create_application (
73
119
router : APIRouter ,
@@ -136,30 +182,13 @@ def create_application(
136
182
if isinstance (settings , EnvironmentSettings ):
137
183
kwargs .update ({"docs_url" : None , "redoc_url" : None , "openapi_url" : None })
138
184
139
- application = FastAPI (** kwargs )
140
-
141
- # --- application created ---
142
- application .include_router (router )
143
- application .add_event_handler ("startup" , set_threadpool_tokens )
185
+ lifespan = lifespan_factory (settings , create_tables_on_start = create_tables_on_start )
144
186
145
- if isinstance (settings , DatabaseSettings ) and create_tables_on_start :
146
- application .add_event_handler ("startup" , create_tables )
147
-
148
- if isinstance (settings , RedisCacheSettings ):
149
- application .add_event_handler ("startup" , create_redis_cache_pool )
150
- application .add_event_handler ("shutdown" , close_redis_cache_pool )
187
+ application = FastAPI (lifespan = lifespan , ** kwargs )
151
188
152
189
if isinstance (settings , ClientSideCacheSettings ):
153
190
application .add_middleware (ClientCacheMiddleware , max_age = settings .CLIENT_CACHE_MAX_AGE )
154
191
155
- if isinstance (settings , RedisQueueSettings ):
156
- application .add_event_handler ("startup" , create_redis_queue_pool )
157
- application .add_event_handler ("shutdown" , close_redis_queue_pool )
158
-
159
- if isinstance (settings , RedisRateLimiterSettings ):
160
- application .add_event_handler ("startup" , create_redis_rate_limit_pool )
161
- application .add_event_handler ("shutdown" , close_redis_rate_limit_pool )
162
-
163
192
if isinstance (settings , EnvironmentSettings ):
164
193
if settings .ENVIRONMENT != EnvironmentOption .PRODUCTION :
165
194
docs_router = APIRouter ()
@@ -181,4 +210,4 @@ async def openapi() -> dict[str, Any]:
181
210
182
211
application .include_router (docs_router )
183
212
184
- return application
213
+ return application
0 commit comments