Skip to content

Commit bcaf0eb

Browse files
authored
Upgrade to SQLAlchemy 2.0 (#197)
## Ticket #82 ## Changes Upgraded to SQLAlchemy 2.0, see context section for major changes in this version ## Context for reviewers SQLAlchemy 2.0 comes with an incredibly thorough migration guide: https://docs.sqlalchemy.org/en/20/changelog/migration_20.html Thankfully 1.4 started nudging usage of SQLAlchemy to use the new approaches, so the changes needed weren't that large. Noteworthy changes: * DB model definitions have been adjusted. The python types now need to be wrapped in `Mapped[...]`, and columns now get defined either with `mapped_column` OR nothing with SQLAlchemy figuring out the type from the python type (similar to Dataclasses or Pydantic). This largely results in more minimal class definitions especially ones with a bunch of basic columns. * DB session/connection/engine logic adjusted slightly to have fewer edge cases. For the changes here that just required some adjustments to the underlying connection setup, something I'd done on another project already to fix some connection issues. Largely this comes down to connections no longer auto-committing and needing to be in `with conn.begin()` blocks to function properly. * Raw SQL calls to `execute` should be wrapped in the `text` class * DeclarativeBase is better supported and instantiated slightly differently. The Metadata object no longer works globally and has to be attached to the DeclarativeBase to work properly. * Typing for MyPy is built into SQLAlchemy and no longer relies on libraries like SQLAlchemy-stubs ## Testing Since the DB hits virtually everything, I tested everything as thoroughly as possible. Because of how the SQLAlchemy-stubs we removed work, they need to be completely deleted otherwise MyPy thinks they should run. If you are running outside of the docker container, run `poetry install --no-root --all-extras --with dev --sync` which will delete extra packages. ### Basics Tests, formatting, linting, all working, only required a few fixes to make SQLAlchemy happy. ### SQLAlchemy warnings When running unit tests, SQLAlchemy will output warnings for deprecated features. The only still usable deprecated feature we had was our `get`queries which were adjusted. Prior to the fix, we would see this warning: ![Screenshot 2023-09-20 at 11 00 38 AM](https://github.com/navapbc/template-application-flask/assets/46358556/7bc1d0f8-15df-4d15-b991-6c57123e70c2) ### Migrations They still work, adding a few new columns generates what we would expect and uses the mapping config added to the Base class. ![Screenshot 2023-09-19 at 3 36 16 PM](https://github.com/navapbc/template-application-flask/assets/46358556/3d3bd20b-618a-473f-9e44-a8e52a6739f5) ![Screenshot 2023-09-19 at 3 36 11 PM](https://github.com/navapbc/template-application-flask/assets/46358556/273f1597-67b4-460d-bb2a-44f512781c06) ### Swagger Was able to successfully use the local swagger endpoints and create/update/read from the DB which was populated like so: <img width="1050" alt="Screenshot 2023-09-20 at 11 34 39 AM" src="https://github.com/navapbc/template-application-flask/assets/46358556/049c8531-8de4-4406-a3b7-9ed92d92da0e">
1 parent c1d1359 commit bcaf0eb

File tree

12 files changed

+134
-152
lines changed

12 files changed

+134
-152
lines changed

app/poetry.lock

Lines changed: 67 additions & 68 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

app/pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ authors = ["Nava Engineering <engineering@navapbc.com>"]
77

88
[tool.poetry.dependencies]
99
python = "^3.10"
10-
SQLAlchemy = {extras = ["mypy"], version = "^1.4.40"}
10+
SQLAlchemy = {extras = ["mypy"], version = "2.0"}
1111
alembic = "^1.8.1"
1212
psycopg2-binary = "^2.9.3"
1313
python-dotenv = "^0.20.0"
@@ -37,6 +37,7 @@ bandit = "^1.7.4"
3737
pytest = "^6.0.0"
3838
pytest-watch = "^4.2.0"
3939
pytest-lazy-fixture = "^0.6.3"
40+
types-pyyaml = "^6.0.12.11"
4041

4142
[build-system]
4243
requires = ["poetry-core>=1.0.0"]
@@ -80,8 +81,6 @@ warn_redundant_casts = true
8081
warn_unreachable = true
8182
warn_unused_ignores = true
8283

83-
plugins = ["sqlalchemy.ext.mypy.plugin"]
84-
8584
[tool.bandit]
8685
# Ignore audit logging test file since test audit logging requires a lot of operations that trigger bandit warnings
8786
exclude_dirs = ["./tests/src/logging/test_audit.py"]

app/src/adapters/db/clients/postgres_client.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ def get_conn() -> Any:
4646
return sqlalchemy.create_engine(
4747
"postgresql://",
4848
pool=conn_pool,
49-
# FYI, execute many mode handles how SQLAlchemy handles doing a bunch of inserts/updates/deletes at once
50-
# https://docs.sqlalchemy.org/en/14/dialects/postgresql.html#psycopg2-fast-execution-helpers
51-
executemany_mode="batch",
5249
hide_parameters=db_config.hide_sql_parameter_logs,
5350
# TODO: Don't think we need this as we aren't using JSON columns, but keeping for reference
5451
# json_serializer=lambda o: json.dumps(o, default=pydantic.json.pydantic_encoder),

app/src/db/migrations/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
def include_object(
3232
object: sqlalchemy.schema.SchemaItem,
33-
name: str,
33+
name: str | None,
3434
type_: str,
3535
reflected: bool,
3636
compare_to: Any,

app/src/db/migrations/run.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
# Convenience script for running alembic migration commands through a pyscript
22
# rather than the command line. This allows poetry to package and alias it for
33
# running on the production docker image from any directory.
4-
import itertools
54
import logging
65
import os
7-
from typing import Optional
86

97
import alembic.command as command
108
import alembic.script as script
119
import sqlalchemy
1210
from alembic.config import Config
13-
from alembic.operations.ops import MigrationScript
1411
from alembic.runtime import migration
1512

1613
logger = logging.getLogger(__name__)
@@ -53,41 +50,3 @@ def have_all_migrations_run(db_engine: sqlalchemy.engine.Engine) -> None:
5350
logger.info(
5451
f"The current migration head is up to date, {current_heads} and Alembic is expecting {expected_heads}"
5552
)
56-
57-
58-
def check_model_parity() -> None:
59-
revisions: list[MigrationScript] = []
60-
61-
def process_revision_directives(
62-
context: migration.MigrationContext,
63-
revision: Optional[str],
64-
directives: list[MigrationScript],
65-
) -> None:
66-
nonlocal revisions
67-
revisions = list(directives)
68-
# Prevent actually generating a migration
69-
directives[:] = []
70-
71-
command.revision(
72-
config=alembic_cfg,
73-
autogenerate=True,
74-
process_revision_directives=process_revision_directives,
75-
)
76-
diff = list(
77-
itertools.chain.from_iterable(
78-
op.as_diffs() for script in revisions for op in script.upgrade_ops_list
79-
)
80-
)
81-
82-
message = (
83-
"The application models are not in sync with the migrations. You should generate "
84-
"a new automigration or update your local migration file. "
85-
"If there are unexpected errors you may need to merge main into your branch."
86-
)
87-
88-
if diff:
89-
for line in diff:
90-
print("::error title=Missing migration::Missing migration:", line)
91-
92-
logger.error(message, extra={"issues": str(diff)})
93-
raise Exception(message)

app/src/db/models/base.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
from typing import Any
55
from uuid import UUID
66

7-
from sqlalchemy import TIMESTAMP, Column, MetaData, inspect
7+
from sqlalchemy import TIMESTAMP, MetaData, Text, inspect
88
from sqlalchemy.dialects import postgresql
9-
from sqlalchemy.ext.declarative import as_declarative
10-
from sqlalchemy.orm import declarative_mixin
9+
from sqlalchemy.orm import DeclarativeBase, Mapped, declarative_mixin, mapped_column
1110
from sqlalchemy.sql.functions import now as sqlnow
1211

1312
from src.util import datetime_util
@@ -26,10 +25,33 @@
2625
)
2726

2827

29-
@as_declarative(metadata=metadata)
30-
class Base:
28+
class Base(DeclarativeBase):
29+
# Attach the metadata to the Base class so all tables automatically get added to the metadata
30+
metadata = metadata
31+
32+
# Override the default type that SQLAlchemy will map python types to.
33+
# This is used if you simply define a column like:
34+
#
35+
# my_column: Mapped[str]
36+
#
37+
# If you provide a mapped_column attribute you can override these values
38+
#
39+
# See: https://docs.sqlalchemy.org/en/20/orm/declarative_tables.html#mapped-column-derives-the-datatype-and-nullability-from-the-mapped-annotation
40+
# for the default mappings
41+
#
42+
# See: https://docs.sqlalchemy.org/en/20/orm/declarative_tables.html#orm-declarative-mapped-column-type-map
43+
# for details on setting up this configuration.
44+
type_annotation_map = {
45+
# Always include a timezone for datetimes
46+
datetime: TIMESTAMP(timezone=True),
47+
# Explicitly use the Text column type for strings
48+
str: Text,
49+
# Always use the Postgres UUID column type
50+
uuid.UUID: postgresql.UUID(as_uuid=True),
51+
}
52+
3153
def _dict(self) -> dict:
32-
return {c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs}
54+
return {c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs} # type: ignore
3355

3456
def for_json(self) -> dict:
3557
json_valid_dict = {}
@@ -46,9 +68,9 @@ def for_json(self) -> dict:
4668

4769
def copy(self, **kwargs: dict[str, Any]) -> "Base":
4870
# TODO - Python 3.11 will let us make the return Self instead
49-
table = self.__table__ # type: ignore
71+
table = self.__table__
5072
non_pk_columns = [
51-
k for k in table.columns.keys() if k not in table.primary_key.columns.keys()
73+
k for k in table.columns.keys() if k not in table.primary_key.columns.keys() # type: ignore
5274
]
5375
data = {c: getattr(self, c) for c in non_pk_columns}
5476
data.update(kwargs)
@@ -59,10 +81,10 @@ def copy(self, **kwargs: dict[str, Any]) -> "Base":
5981
@declarative_mixin
6082
class IdMixin:
6183
"""Mixin to add a UUID id primary key column to a model
62-
https://docs.sqlalchemy.org/en/14/orm/declarative_mixins.html
84+
https://docs.sqlalchemy.org/en/20/orm/declarative_mixins.html
6385
"""
6486

65-
id: uuid.UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
87+
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
6688

6789

6890
def same_as_created_at(context: Any) -> Any:
@@ -72,18 +94,16 @@ def same_as_created_at(context: Any) -> Any:
7294
@declarative_mixin
7395
class TimestampMixin:
7496
"""Mixin to add created_at and updated_at columns to a model
75-
https://docs.sqlalchemy.org/en/14/orm/declarative_mixins.html#mixing-in-columns
97+
https://docs.sqlalchemy.org/en/20/orm/declarative_mixins.html#mixing-in-columns
7698
"""
7799

78-
created_at: datetime = Column(
79-
TIMESTAMP(timezone=True),
100+
created_at: Mapped[datetime] = mapped_column(
80101
nullable=False,
81102
default=datetime_util.utcnow,
82103
server_default=sqlnow(),
83104
)
84105

85-
updated_at: datetime = Column(
86-
TIMESTAMP(timezone=True),
106+
updated_at: Mapped[datetime] = mapped_column(
87107
nullable=False,
88108
default=same_as_created_at,
89109
onupdate=datetime_util.utcnow,

app/src/db/models/user_models.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
from typing import Optional
55
from uuid import UUID
66

7-
from sqlalchemy import Boolean, Column, Date, Enum, ForeignKey, Text
8-
from sqlalchemy.dialects import postgresql
9-
from sqlalchemy.orm import Mapped, relationship
7+
from sqlalchemy import Enum, ForeignKey
8+
from sqlalchemy.orm import Mapped, mapped_column, relationship
109

1110
from src.db.models.base import Base, IdMixin, TimestampMixin
1211

@@ -21,22 +20,22 @@ class RoleType(str, enum.Enum):
2120
class User(Base, IdMixin, TimestampMixin):
2221
__tablename__ = "user"
2322

24-
first_name: str = Column(Text, nullable=False)
25-
middle_name: Optional[str] = Column(Text)
26-
last_name: str = Column(Text, nullable=False)
27-
phone_number: str = Column(Text, nullable=False)
28-
date_of_birth: date = Column(Date, nullable=False)
29-
is_active: bool = Column(Boolean, nullable=False)
23+
first_name: Mapped[str]
24+
middle_name: Mapped[Optional[str]]
25+
last_name: Mapped[str]
26+
phone_number: Mapped[str]
27+
date_of_birth: Mapped[date]
28+
is_active: Mapped[bool]
3029

31-
roles: list["Role"] = relationship(
30+
roles: Mapped[list["Role"]] = relationship(
3231
"Role", back_populates="user", cascade="all, delete", order_by="Role.type"
3332
)
3433

3534

3635
class Role(Base, TimestampMixin):
3736
__tablename__ = "role"
38-
user_id: Mapped[UUID] = Column(
39-
postgresql.UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
37+
user_id: Mapped[UUID] = mapped_column(
38+
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
4039
)
4140

4241
# Set native_enum=False to use store enum values as VARCHAR/TEXT
@@ -48,6 +47,6 @@ class Role(Base, TimestampMixin):
4847
# not yet functional
4948
# (See https://github.com/sqlalchemy/alembic/issues/363)
5049
#
51-
# https://docs.sqlalchemy.org/en/14/core/type_basics.html#sqlalchemy.types.Enum.params.native_enum
52-
type: RoleType = Column(Enum(RoleType, native_enum=False), primary_key=True)
53-
user: User = relationship(User, back_populates="roles")
50+
# https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Enum.params.native_enum
51+
type: Mapped[RoleType] = mapped_column(Enum(RoleType, native_enum=False), primary_key=True)
52+
user: Mapped[User] = relationship(User, back_populates="roles")

app/src/services/users/get_user.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# https://github.com/navapbc/template-application-flask/issues/52
1212
def get_user(db_session: Session, user_id: str) -> User:
1313
# TODO: move this to service and/or persistence layer
14-
result = db_session.query(User).options(orm.selectinload(User.roles)).get(user_id)
14+
result = db_session.get(User, user_id, options=[orm.selectinload(User.roles)])
1515

1616
if result is None:
1717
# TODO move HTTP related logic out of service layer to controller layer and just return None from here

app/src/services/users/patch_user.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def patch_user(
3333

3434
with db_session.begin():
3535
# TODO: move this to service and/or persistence layer
36-
user = db_session.query(User).options(orm.selectinload(User.roles)).get(user_id)
36+
user = db_session.get(User, user_id, options=[orm.selectinload(User.roles)])
3737

3838
if user is None:
3939
# TODO move HTTP related logic out of service layer to controller layer and just return None from here

app/tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def db_client(monkeypatch_session) -> db.DBClient:
8080
"""
8181

8282
with db_testing.create_isolated_db(monkeypatch_session) as db_client:
83-
models.metadata.create_all(bind=db_client.get_connection())
83+
with db_client.get_connection() as conn, conn.begin():
84+
models.metadata.create_all(bind=conn)
8485
yield db_client
8586

8687

0 commit comments

Comments
 (0)