Skip to content

Commit 6f4c180

Browse files
committed
Parameterize more tests to use 2.x
1 parent a89485b commit 6f4c180

File tree

5 files changed

+206
-50
lines changed

5 files changed

+206
-50
lines changed

tests/test_extension_object.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
import sqlalchemy as sa
7+
import sqlalchemy.orm as sa_orm
78
from flask import Flask
89
from werkzeug.exceptions import NotFound
910

@@ -22,17 +23,45 @@ def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None:
2223
db.get_or_404(Todo, 2)
2324

2425

25-
def test_get_or_404_kwargs(app: Flask) -> None:
26+
def test_get_or_404_kwargs(app: Flask, model_class: t.Any) -> None:
2627
app.config["SQLALCHEMY_RECORD_QUERIES"] = True
27-
db = SQLAlchemy(app)
28+
db = SQLAlchemy(app, model_class=model_class)
2829

29-
class User(db.Model):
30-
id = sa.Column(db.Integer, primary_key=True) # type: ignore[var-annotated]
30+
if issubclass(db.Model, (sa_orm.MappedAsDataclass)):
3131

32-
class Todo(db.Model):
33-
id = sa.Column(sa.Integer, primary_key=True)
34-
user_id = sa.Column(sa.ForeignKey(User.id)) # type: ignore[var-annotated]
35-
user = db.relationship(User)
32+
class User(db.Model): # type: ignore[no-redef]
33+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(
34+
sa.Integer, primary_key=True, init=False
35+
)
36+
37+
class Todo(db.Model): # type: ignore[no-redef]
38+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(
39+
sa.Integer, primary_key=True, init=False
40+
)
41+
user_id: sa_orm.Mapped[int] = sa_orm.mapped_column(
42+
sa.ForeignKey(User.id), init=False
43+
)
44+
user: sa_orm.Mapped[User] = sa_orm.relationship(User)
45+
46+
elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)):
47+
48+
class User(db.Model): # type: ignore[no-redef]
49+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True)
50+
51+
class Todo(db.Model): # type: ignore[no-redef]
52+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True)
53+
user_id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.ForeignKey(User.id))
54+
user: sa_orm.Mapped[User] = sa_orm.relationship(User)
55+
56+
else:
57+
58+
class User(db.Model): # type: ignore[no-redef]
59+
id = sa.Column(db.Integer, primary_key=True) # type: ignore[var-annotated]
60+
61+
class Todo(db.Model): # type: ignore[no-redef]
62+
id = sa.Column(sa.Integer, primary_key=True)
63+
user_id = sa.Column(sa.ForeignKey(User.id)) # type: ignore[var-annotated]
64+
user = db.relationship(User)
3665

3766
with app.app_context():
3867
db.create_all()

tests/test_model.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from flask_sqlalchemy.model import Model
1414

1515

16-
def test_default_model_class(app: Flask) -> None:
16+
def test_default_model_class_1x(app: Flask) -> None:
1717
db = SQLAlchemy(app)
1818

1919
assert db.Model.query_class is db.Query
@@ -22,7 +22,7 @@ def test_default_model_class(app: Flask) -> None:
2222
assert isinstance(db.Model, DefaultMeta)
2323

2424

25-
def test_custom_model_class(app: Flask) -> None:
25+
def test_custom_model_class_1x(app: Flask) -> None:
2626
class CustomModel(Model):
2727
pass
2828

@@ -33,7 +33,7 @@ class CustomModel(Model):
3333

3434
@pytest.mark.usefixtures("app_ctx")
3535
@pytest.mark.parametrize("base", [Model, object])
36-
def test_custom_declarative_class(app: Flask, base: t.Any) -> None:
36+
def test_custom_declarative_class_1x(app: Flask, base: t.Any) -> None:
3737
class CustomMeta(DefaultMeta):
3838
pass
3939

@@ -44,6 +44,42 @@ class CustomMeta(DefaultMeta):
4444
assert "query" in db.Model.__dict__
4545

4646

47+
def test_declarativebase_2x(app: Flask) -> None:
48+
class Base(sa_orm.DeclarativeBase):
49+
pass
50+
51+
db = SQLAlchemy(app, model_class=Base)
52+
assert issubclass(db.Model, sa_orm.DeclarativeBase)
53+
assert isinstance(db.Model, sa_orm.decl_api.DeclarativeAttributeIntercept)
54+
55+
56+
def test_declarativebasenometa_2x(app: Flask) -> None:
57+
class Base(sa_orm.DeclarativeBaseNoMeta):
58+
pass
59+
60+
db = SQLAlchemy(app, model_class=Base)
61+
assert issubclass(db.Model, sa_orm.DeclarativeBaseNoMeta)
62+
assert not isinstance(db.Model, sa_orm.decl_api.DeclarativeAttributeIntercept)
63+
64+
65+
def test_declarativebasemapped_2x(app: Flask) -> None:
66+
class Base(sa_orm.DeclarativeBase, sa_orm.MappedAsDataclass):
67+
pass
68+
69+
db = SQLAlchemy(app, model_class=Base)
70+
assert issubclass(db.Model, sa_orm.DeclarativeBase)
71+
assert isinstance(db.Model, sa_orm.decl_api.DCTransformDeclarative)
72+
73+
74+
def test_declarativebasenometamapped_2x(app: Flask) -> None:
75+
class Base(sa_orm.DeclarativeBaseNoMeta, sa_orm.MappedAsDataclass):
76+
pass
77+
78+
db = SQLAlchemy(app, model_class=Base)
79+
assert issubclass(db.Model, sa_orm.DeclarativeBaseNoMeta)
80+
assert isinstance(db.Model, sa_orm.decl_api.DCTransformDeclarative)
81+
82+
4783
@pytest.mark.usefixtures("app_ctx")
4884
def test_model_repr(db: SQLAlchemy) -> None:
4985
class User(db.Model):

tests/test_record_queries.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
import sqlalchemy as sa
7+
import sqlalchemy.orm as sa_orm
78
from flask import Flask
89

910
from flask_sqlalchemy import SQLAlchemy
@@ -15,15 +16,35 @@ def test_query_info(app: Flask) -> None:
1516
app.config["SQLALCHEMY_RECORD_QUERIES"] = True
1617
db = SQLAlchemy(app)
1718

18-
class Example(db.Model):
19-
id = sa.Column(sa.Integer, primary_key=True)
19+
# Copied and pasted from conftest.py
20+
if issubclass(db.Model, (sa_orm.MappedAsDataclass)):
21+
22+
class Todo(db.Model):
23+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(
24+
sa.Integer, init=False, primary_key=True
25+
)
26+
title: sa_orm.Mapped[str] = sa_orm.mapped_column(
27+
sa.String, nullable=True, default=None
28+
)
29+
30+
elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)):
31+
32+
class Todo(db.Model): # type: ignore[no-redef]
33+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True)
34+
title: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=True)
35+
36+
else:
37+
38+
class Todo(db.Model): # type: ignore[no-redef]
39+
id = sa.Column(sa.Integer, primary_key=True)
40+
title = sa.Column(sa.String)
2041

2142
db.create_all()
22-
db.session.execute(sa.select(Example).filter(Example.id < 5)).scalars()
43+
db.session.execute(sa.select(Todo).filter(Todo.id < 5)).scalars()
2344
info = get_recorded_queries()[-1]
2445
assert info.statement is not None
2546
assert "SELECT" in info.statement
26-
assert "FROM example" in info.statement
47+
assert "FROM todo" in info.statement
2748
assert info.parameters[0][0] == 5
2849
assert info.duration == info.end_time - info.start_time
2950
assert os.path.join("tests", "test_record_queries.py:") in info.location

tests/test_session.py

Lines changed: 76 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from __future__ import annotations
22

3+
import typing as t
4+
35
import pytest
46
import sqlalchemy as sa
7+
import sqlalchemy.orm as sa_orm
58
from flask import Flask
69

710
from flask_sqlalchemy import SQLAlchemy
@@ -23,15 +26,15 @@ def test_scope(app: Flask, db: SQLAlchemy) -> None:
2326
assert first is not third
2427

2528

26-
def test_custom_scope(app: Flask) -> None:
29+
def test_custom_scope(app: Flask, model_class: t.Any) -> None:
2730
count = 0
2831

2932
def scope() -> int:
3033
nonlocal count
3134
count += 1
3235
return count
3336

34-
db = SQLAlchemy(app, session_options={"scopefunc": scope})
37+
db = SQLAlchemy(app, model_class=model_class, session_options={"scopefunc": scope})
3538

3639
with app.app_context():
3740
first = db.session()
@@ -42,47 +45,94 @@ def scope() -> int:
4245

4346

4447
@pytest.mark.usefixtures("app_ctx")
45-
def test_session_class(app: Flask) -> None:
48+
def test_session_class(app: Flask, model_class: t.Any) -> None:
4649
class CustomSession(Session):
4750
pass
4851

49-
db = SQLAlchemy(app, session_options={"class_": CustomSession})
52+
db = SQLAlchemy(
53+
app, model_class=model_class, session_options={"class_": CustomSession}
54+
)
5055
assert isinstance(db.session(), CustomSession)
5156

5257

5358
@pytest.mark.usefixtures("app_ctx")
54-
def test_session_uses_bind_key(app: Flask) -> None:
59+
def test_session_uses_bind_key(app: Flask, model_class: t.Any) -> None:
5560
app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"}
56-
db = SQLAlchemy(app)
61+
db = SQLAlchemy(app, model_class=model_class)
5762

58-
class User(db.Model):
59-
id = sa.Column(sa.Integer, primary_key=True)
63+
if issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)):
6064

61-
class Post(db.Model):
62-
__bind_key__ = "a"
63-
id = sa.Column(sa.Integer, primary_key=True)
65+
class User(db.Model):
66+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True)
6467

65-
assert db.session.get_bind(mapper=User) is db.engine
66-
assert db.session.get_bind(mapper=Post) is db.engines["a"]
68+
class Post(db.Model):
69+
__bind_key__ = "a"
70+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True)
6771

72+
else:
6873

69-
@pytest.mark.usefixtures("app_ctx")
70-
def test_get_bind_inheritance(app: Flask) -> None:
71-
app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"}
72-
db = SQLAlchemy(app)
74+
class User(db.Model): # type: ignore[no-redef]
75+
id = sa.Column(sa.Integer, primary_key=True)
7376

74-
class User(db.Model):
75-
__bind_key__ = "a"
76-
id = sa.Column(sa.Integer, primary_key=True)
77-
type = sa.Column(sa.String, nullable=False)
77+
class Post(db.Model): # type: ignore[no-redef]
78+
__bind_key__ = "a"
79+
id = sa.Column(sa.Integer, primary_key=True)
7880

79-
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"}
81+
assert db.session.get_bind(mapper=User) is db.engine
82+
assert db.session.get_bind(mapper=Post) is db.engines["a"]
8083

81-
class Admin(User):
82-
id = sa.Column(sa.ForeignKey(User.id), primary_key=True)
83-
org = sa.Column(sa.String, nullable=False)
8484

85-
__mapper_args__ = {"polymorphic_identity": "admin"}
85+
@pytest.mark.usefixtures("app_ctx")
86+
def test_get_bind_inheritance(app: Flask, model_class: t.Any) -> None:
87+
app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"}
88+
db = SQLAlchemy(app, model_class=model_class)
89+
90+
if issubclass(db.Model, (sa_orm.MappedAsDataclass)):
91+
92+
class User(db.Model):
93+
__bind_key__ = "a"
94+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(
95+
sa.Integer, primary_key=True, init=False
96+
)
97+
type: sa_orm.Mapped[str] = sa_orm.mapped_column(
98+
sa.String, nullable=False, init=False
99+
)
100+
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"}
101+
102+
class Admin(User):
103+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(
104+
sa.ForeignKey(User.id), primary_key=True, init=False
105+
)
106+
org: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=False)
107+
__mapper_args__ = {"polymorphic_identity": "admin"}
108+
109+
elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)):
110+
111+
class User(db.Model):
112+
__bind_key__ = "a"
113+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True)
114+
type: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=False)
115+
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"}
116+
117+
class Admin(User):
118+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(
119+
sa.ForeignKey(User.id), primary_key=True
120+
)
121+
org: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=False)
122+
__mapper_args__ = {"polymorphic_identity": "admin"}
123+
124+
else:
125+
126+
class User(db.Model): # type: ignore[no-redef]
127+
__bind_key__ = "a"
128+
id = sa.Column(sa.Integer, primary_key=True)
129+
type = sa.Column(sa.String, nullable=False)
130+
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"}
131+
132+
class Admin(User): # type: ignore[no-redef]
133+
id = sa.Column(sa.ForeignKey(User.id), primary_key=True)
134+
org = sa.Column(sa.String, nullable=False)
135+
__mapper_args__ = {"polymorphic_identity": "admin"}
86136

87137
db.create_all()
88138
db.session.add(Admin(org="pallets"))

0 commit comments

Comments
 (0)