11from  __future__ import  annotations 
22
3+ import  typing  as  t 
4+ 
35import  pytest 
46import  sqlalchemy  as  sa 
7+ import  sqlalchemy .orm  as  sa_orm 
58from  flask  import  Flask 
69
710from  flask_sqlalchemy  import  SQLAlchemy 
@@ -23,15 +26,15 @@ def test_scope(app: Flask, db: SQLAlchemy) -> None:
2326        assert  first  is  not third 
2427
2528
26- def  test_custom_scope (app : Flask ) ->  None :
29+ def  test_custom_scope (app : Flask ,  model_class :  t . Any ) ->  None :
2730    count  =  0 
2831
2932    def  scope () ->  int :
3033        nonlocal  count 
3134        count  +=  1 
3235        return  count 
3336
34-     db  =  SQLAlchemy (app , session_options = {"scopefunc" : scope })
37+     db  =  SQLAlchemy (app , model_class = model_class ,  session_options = {"scopefunc" : scope })
3538
3639    with  app .app_context ():
3740        first  =  db .session ()
@@ -42,47 +45,94 @@ def scope() -> int:
4245
4346
4447@pytest .mark .usefixtures ("app_ctx" ) 
45- def  test_session_class (app : Flask ) ->  None :
48+ def  test_session_class (app : Flask ,  model_class :  t . Any ) ->  None :
4649    class  CustomSession (Session ):
4750        pass 
4851
49-     db  =  SQLAlchemy (app , session_options = {"class_" : CustomSession })
52+     db  =  SQLAlchemy (
53+         app , model_class = model_class , session_options = {"class_" : CustomSession }
54+     )
5055    assert  isinstance (db .session (), CustomSession )
5156
5257
5358@pytest .mark .usefixtures ("app_ctx" ) 
54- def  test_session_uses_bind_key (app : Flask ) ->  None :
59+ def  test_session_uses_bind_key (app : Flask ,  model_class :  t . Any ) ->  None :
5560    app .config ["SQLALCHEMY_BINDS" ] =  {"a" : "sqlite://" }
56-     db  =  SQLAlchemy (app )
61+     db  =  SQLAlchemy (app ,  model_class = model_class )
5762
58-     class  User (db .Model ):
59-         id  =  sa .Column (sa .Integer , primary_key = True )
63+     if  issubclass (db .Model , (sa_orm .DeclarativeBase , sa_orm .DeclarativeBaseNoMeta )):
6064
61-     class  Post (db .Model ):
62-         __bind_key__  =  "a" 
63-         id  =  sa .Column (sa .Integer , primary_key = True )
65+         class  User (db .Model ):
66+             id : sa_orm .Mapped [int ] =  sa_orm .mapped_column (sa .Integer , primary_key = True )
6467
65-     assert  db .session .get_bind (mapper = User ) is  db .engine 
66-     assert  db .session .get_bind (mapper = Post ) is  db .engines ["a" ]
68+         class  Post (db .Model ):
69+             __bind_key__  =  "a" 
70+             id : sa_orm .Mapped [int ] =  sa_orm .mapped_column (sa .Integer , primary_key = True )
6771
72+     else :
6873
69- @pytest .mark .usefixtures ("app_ctx" ) 
70- def  test_get_bind_inheritance (app : Flask ) ->  None :
71-     app .config ["SQLALCHEMY_BINDS" ] =  {"a" : "sqlite://" }
72-     db  =  SQLAlchemy (app )
74+         class  User (db .Model ):  # type: ignore[no-redef] 
75+             id  =  sa .Column (sa .Integer , primary_key = True )
7376
74-     class  User (db .Model ):
75-         __bind_key__  =  "a" 
76-         id  =  sa .Column (sa .Integer , primary_key = True )
77-         type  =  sa .Column (sa .String , nullable = False )
77+         class  Post (db .Model ):  # type: ignore[no-redef] 
78+             __bind_key__  =  "a" 
79+             id  =  sa .Column (sa .Integer , primary_key = True )
7880
79-         __mapper_args__  =  {"polymorphic_on" : type , "polymorphic_identity" : "user" }
81+     assert  db .session .get_bind (mapper = User ) is  db .engine 
82+     assert  db .session .get_bind (mapper = Post ) is  db .engines ["a" ]
8083
81-     class  Admin (User ):
82-         id  =  sa .Column (sa .ForeignKey (User .id ), primary_key = True )
83-         org  =  sa .Column (sa .String , nullable = False )
8484
85-         __mapper_args__  =  {"polymorphic_identity" : "admin" }
85+ @pytest .mark .usefixtures ("app_ctx" ) 
86+ def  test_get_bind_inheritance (app : Flask , model_class : t .Any ) ->  None :
87+     app .config ["SQLALCHEMY_BINDS" ] =  {"a" : "sqlite://" }
88+     db  =  SQLAlchemy (app , model_class = model_class )
89+ 
90+     if  issubclass (db .Model , (sa_orm .MappedAsDataclass )):
91+ 
92+         class  User (db .Model ):
93+             __bind_key__  =  "a" 
94+             id : sa_orm .Mapped [int ] =  sa_orm .mapped_column (
95+                 sa .Integer , primary_key = True , init = False 
96+             )
97+             type : sa_orm .Mapped [str ] =  sa_orm .mapped_column (
98+                 sa .String , nullable = False , init = False 
99+             )
100+             __mapper_args__  =  {"polymorphic_on" : type , "polymorphic_identity" : "user" }
101+ 
102+         class  Admin (User ):
103+             id : sa_orm .Mapped [int ] =  sa_orm .mapped_column (
104+                 sa .ForeignKey (User .id ), primary_key = True , init = False 
105+             )
106+             org : sa_orm .Mapped [str ] =  sa_orm .mapped_column (sa .String , nullable = False )
107+             __mapper_args__  =  {"polymorphic_identity" : "admin" }
108+ 
109+     elif  issubclass (db .Model , (sa_orm .DeclarativeBase , sa_orm .DeclarativeBaseNoMeta )):
110+ 
111+         class  User (db .Model ):
112+             __bind_key__  =  "a" 
113+             id : sa_orm .Mapped [int ] =  sa_orm .mapped_column (sa .Integer , primary_key = True )
114+             type : sa_orm .Mapped [str ] =  sa_orm .mapped_column (sa .String , nullable = False )
115+             __mapper_args__  =  {"polymorphic_on" : type , "polymorphic_identity" : "user" }
116+ 
117+         class  Admin (User ):
118+             id : sa_orm .Mapped [int ] =  sa_orm .mapped_column (
119+                 sa .ForeignKey (User .id ), primary_key = True 
120+             )
121+             org : sa_orm .Mapped [str ] =  sa_orm .mapped_column (sa .String , nullable = False )
122+             __mapper_args__  =  {"polymorphic_identity" : "admin" }
123+ 
124+     else :
125+ 
126+         class  User (db .Model ):  # type: ignore[no-redef] 
127+             __bind_key__  =  "a" 
128+             id  =  sa .Column (sa .Integer , primary_key = True )
129+             type  =  sa .Column (sa .String , nullable = False )
130+             __mapper_args__  =  {"polymorphic_on" : type , "polymorphic_identity" : "user" }
131+ 
132+         class  Admin (User ):  # type: ignore[no-redef] 
133+             id  =  sa .Column (sa .ForeignKey (User .id ), primary_key = True )
134+             org  =  sa .Column (sa .String , nullable = False )
135+             __mapper_args__  =  {"polymorphic_identity" : "admin" }
86136
87137    db .create_all ()
88138    db .session .add (Admin (org = "pallets" ))
0 commit comments