1
1
import asyncio
2
2
from contextvars import ContextVar
3
- from typing import Dict , Optional , Union
3
+ from typing import Dict , Optional , Type , Union
4
4
5
- from sqlalchemy .engine import Engine
6
5
from sqlalchemy .engine .url import URL
7
- from sqlalchemy .ext .asyncio import AsyncSession , create_async_engine
6
+ from sqlalchemy .ext .asyncio import AsyncEngine , AsyncSession , create_async_engine
8
7
from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
9
8
from starlette .requests import Request
10
9
from starlette .types import ASGIApp
11
10
12
- from fastapi_async_sqlalchemy .exceptions import MissingSessionError , SessionNotInitialisedError
11
+ from fastapi_async_sqlalchemy .exceptions import (
12
+ MissingSessionError ,
13
+ SessionNotInitialisedError ,
14
+ )
13
15
14
16
try :
15
- from sqlalchemy .ext .asyncio import async_sessionmaker # noqa: F811
17
+ from sqlalchemy .ext .asyncio import async_sessionmaker
16
18
except ImportError :
17
- from sqlalchemy .orm import sessionmaker as async_sessionmaker
19
+ from sqlalchemy .orm import sessionmaker as async_sessionmaker # type: ignore
18
20
19
21
# Try to import SQLModel's AsyncSession which has the .exec() method
20
22
try :
21
23
from sqlmodel .ext .asyncio .session import AsyncSession as SQLModelAsyncSession
22
24
23
- DefaultAsyncSession = SQLModelAsyncSession
25
+ DefaultAsyncSession : Type [ AsyncSession ] = SQLModelAsyncSession # type: ignore
24
26
except ImportError :
25
- DefaultAsyncSession = AsyncSession
27
+ DefaultAsyncSession : Type [ AsyncSession ] = AsyncSession # type: ignore
26
28
27
29
28
- def create_middleware_and_session_proxy ():
30
+ def create_middleware_and_session_proxy () -> tuple :
29
31
_Session : Optional [async_sessionmaker ] = None
30
- _session : ContextVar [Optional [DefaultAsyncSession ]] = ContextVar ("_session" , default = None )
32
+ _session : ContextVar [Optional [AsyncSession ]] = ContextVar ("_session" , default = None )
31
33
_multi_sessions_ctx : ContextVar [bool ] = ContextVar ("_multi_sessions_context" , default = False )
32
34
_commit_on_exit_ctx : ContextVar [bool ] = ContextVar ("_commit_on_exit_ctx" , default = False )
33
35
# Usage of context vars inside closures is not recommended, since they are not properly
@@ -39,9 +41,9 @@ def __init__(
39
41
self ,
40
42
app : ASGIApp ,
41
43
db_url : Optional [Union [str , URL ]] = None ,
42
- custom_engine : Optional [Engine ] = None ,
43
- engine_args : Dict = None ,
44
- session_args : Dict = None ,
44
+ custom_engine : Optional [AsyncEngine ] = None ,
45
+ engine_args : Optional [ Dict ] = None ,
46
+ session_args : Optional [ Dict ] = None ,
45
47
commit_on_exit : bool = False ,
46
48
):
47
49
super ().__init__ (app )
@@ -52,13 +54,18 @@ def __init__(
52
54
if not custom_engine and not db_url :
53
55
raise ValueError ("You need to pass a db_url or a custom_engine parameter." )
54
56
if not custom_engine :
57
+ if db_url is None :
58
+ raise ValueError ("db_url cannot be None when custom_engine is not provided" )
55
59
engine = create_async_engine (db_url , ** engine_args )
56
60
else :
57
61
engine = custom_engine
58
62
59
63
nonlocal _Session
60
64
_Session = async_sessionmaker (
61
- engine , class_ = DefaultAsyncSession , expire_on_commit = False , ** session_args
65
+ engine ,
66
+ class_ = DefaultAsyncSession ,
67
+ expire_on_commit = False ,
68
+ ** session_args ,
62
69
)
63
70
64
71
async def dispatch (self , request : Request , call_next : RequestResponseEndpoint ):
@@ -67,7 +74,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
67
74
68
75
class DBSessionMeta (type ):
69
76
@property
70
- def session (self ) -> DefaultAsyncSession :
77
+ def session (self ) -> AsyncSession :
71
78
"""Return an instance of Session local to the current async context."""
72
79
if _Session is None :
73
80
raise SessionNotInitialisedError
@@ -123,7 +130,7 @@ async def cleanup():
123
130
class DBSession (metaclass = DBSessionMeta ):
124
131
def __init__ (
125
132
self ,
126
- session_args : Dict = None ,
133
+ session_args : Optional [ Dict ] = None ,
127
134
commit_on_exit : bool = False ,
128
135
multi_sessions : bool = False ,
129
136
):
0 commit comments