Skip to content

Commit 7b79d23

Browse files
authored
feat: soft-deletion of database objects (#102)
* feat: database softdelete * feat: cascade delete * feat: add cascading delete on BiocommonsGroup and Platform * feat: add migration script * fix: add warning to migration
1 parent 59e7707 commit 7b79d23

File tree

6 files changed

+1103
-22
lines changed

6 files changed

+1103
-22
lines changed

db/core.py

Lines changed: 161 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1-
from typing import ClassVar
1+
from __future__ import annotations
22

3-
from sqlalchemy import MetaData
4-
from sqlmodel import SQLModel
3+
from typing import Any, ClassVar
4+
5+
from sqlalchemy import MetaData, event, select
6+
from sqlalchemy import inspect as sa_inspect
7+
from sqlalchemy.exc import IntegrityError
8+
from sqlalchemy.orm import Session as SASession
9+
from sqlalchemy.orm import with_loader_criteria
10+
from sqlalchemy.sql import expression
11+
from sqlmodel import Field, Session, SQLModel
512

613
naming_convention = {
714
"ix": "ix_%(column_0_label)s",
@@ -27,4 +34,154 @@ class BaseModel(SQLModel):
2734
metadata: ClassVar[MetaData] = metadata
2835

2936

30-
__all__ = ["BaseModel"]
37+
class SoftDeleteModel(BaseModel):
38+
"""
39+
Base for ORM models that should support soft deletion.
40+
"""
41+
__abstract__ = True
42+
43+
is_deleted: bool = Field(
44+
default=False,
45+
nullable=False,
46+
index=True,
47+
sa_column_kwargs={"server_default": expression.false()},
48+
description="Soft-delete flag. True means the row is hidden from default queries.",
49+
)
50+
51+
def delete(self, session: Session, commit: bool = False) -> "SoftDeleteModel":
52+
"""
53+
Soft delete this record (mark as deleted without removing from DB).
54+
"""
55+
self.is_deleted = True
56+
session.add(self)
57+
if commit:
58+
session.commit()
59+
session.expunge(self)
60+
return self
61+
62+
def restore(self, session: Session, commit: bool = False) -> "SoftDeleteModel":
63+
"""
64+
Restore a previously deleted record.
65+
"""
66+
self.is_deleted = False
67+
session.add(self)
68+
if commit:
69+
session.commit()
70+
return self
71+
72+
@classmethod
73+
def get_deleted_by_id(cls, session: Session, identity: Any) -> "SoftDeleteModel | None":
74+
"""
75+
Retrieve a soft-deleted record by primary key.
76+
"""
77+
identity_dict = cls._coerce_primary_key_map(identity)
78+
stmt = (
79+
select(cls)
80+
.execution_options(include_deleted=True)
81+
.filter_by(**identity_dict)
82+
.where(cls.is_deleted.is_(True))
83+
)
84+
return session.exec(stmt).scalars().one_or_none()
85+
86+
@classmethod
87+
def _coerce_primary_key_map(cls, identity: Any) -> dict[str, Any]:
88+
"""
89+
Coerce arbitrary primary key identifiers into a ``{column: value}`` mapping.
90+
91+
``identity`` may be provided as:
92+
* a dict where keys match the primary key columns,
93+
* a single scalar value when the model uses a single-column primary key,
94+
* a tuple/list containing values for each primary key column in order.
95+
96+
Any mismatch between the provided structure and the model's primary key
97+
definition raises ``ValueError`` so downstream queries remain predictable.
98+
"""
99+
mapper = sa_inspect(cls)
100+
pk_cols = mapper.primary_key
101+
if not pk_cols:
102+
raise ValueError(f"{cls.__name__} does not have a primary key defined.")
103+
104+
if isinstance(identity, dict):
105+
return identity
106+
107+
if len(pk_cols) == 1 and not isinstance(identity, (tuple, list)):
108+
return {pk_cols[0].key: identity}
109+
110+
if isinstance(identity, (tuple, list)):
111+
if len(identity) != len(pk_cols):
112+
raise ValueError(
113+
f"Identity length {len(identity)} does not match primary key length {len(pk_cols)}"
114+
)
115+
return {col.key: value for col, value in zip(pk_cols, identity)}
116+
117+
raise ValueError("Identity must be scalar, tuple/list, or dict matching the primary key.")
118+
119+
120+
def _copy_column_state(source: SoftDeleteModel, target: SoftDeleteModel) -> None:
121+
mapper = sa_inspect(source.__class__)
122+
for attr in mapper.column_attrs:
123+
key = attr.key
124+
if key == "is_deleted":
125+
continue
126+
setattr(target, key, getattr(source, key))
127+
128+
129+
def _identity_dict_from_instance(instance: SoftDeleteModel) -> dict[str, Any] | None:
130+
mapper = sa_inspect(instance.__class__)
131+
identity: dict[str, Any] = {}
132+
for column in mapper.primary_key:
133+
value = getattr(instance, column.key, None)
134+
if value is None:
135+
return None
136+
identity[column.key] = value
137+
return identity
138+
139+
140+
def _soft_delete_filter(cls) -> Any:
141+
mapper = sa_inspect(cls, raiseerr=False)
142+
if mapper is None:
143+
return expression.true()
144+
column = mapper.c.is_deleted
145+
return column.is_(False)
146+
147+
148+
@event.listens_for(SASession, "before_flush")
149+
def _revive_soft_deleted(session: SASession, flush_context, instances) -> None:
150+
for obj in list(session.new):
151+
if not isinstance(obj, SoftDeleteModel):
152+
continue
153+
identity_dict = _identity_dict_from_instance(obj)
154+
if identity_dict is None:
155+
continue
156+
stmt = (
157+
select(obj.__class__)
158+
.execution_options(include_deleted=True)
159+
.filter_by(**identity_dict)
160+
)
161+
existing = session.exec(stmt).scalars().one_or_none()
162+
if existing is None:
163+
continue
164+
if not existing.is_deleted:
165+
raise IntegrityError(
166+
"Duplicate primary key for active record",
167+
params=None,
168+
orig=None,
169+
)
170+
_copy_column_state(obj, existing)
171+
existing.is_deleted = False
172+
session.expunge(obj)
173+
174+
175+
@event.listens_for(SASession, "do_orm_execute")
176+
def _filter_soft_deleted(execute_state) -> None:
177+
if execute_state.is_select and not execute_state.execution_options.get("include_deleted", False):
178+
execute_state.statement = execute_state.statement.options(
179+
with_loader_criteria(
180+
SoftDeleteModel,
181+
lambda cls: _soft_delete_filter(cls),
182+
include_aliases=True,
183+
)
184+
)
185+
186+
187+
__all__ = ["BaseModel", "SoftDeleteModel"]

db/models.py

Lines changed: 91 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import schemas
1111
from auth0.client import Auth0Client
12-
from db.core import BaseModel
12+
from db.core import SoftDeleteModel
1313
from db.types import (
1414
ApprovalStatusEnum,
1515
GroupMembershipData,
@@ -19,7 +19,7 @@
1919
from schemas.user import SessionUser
2020

2121

22-
class BiocommonsUser(BaseModel, table=True):
22+
class BiocommonsUser(SoftDeleteModel, table=True):
2323
__tablename__ = "biocommons_user"
2424
# Auth0 ID
2525
id: str = Field(primary_key=True)
@@ -50,13 +50,14 @@ def has_platform_membership(cls, user_id: str, platform_id: PlatformEnum, sessio
5050
"""
5151
Check if a user has a membership for a specific platform.
5252
"""
53-
return session.exec(
54-
select(PlatformMembership).where(
53+
result = session.exec(
54+
select(PlatformMembership.id).where(
5555
PlatformMembership.user_id == user_id,
5656
PlatformMembership.platform_id == platform_id,
5757
PlatformMembership.approval_status == ApprovalStatusEnum.APPROVED,
5858
)
59-
).exists()
59+
).first()
60+
return result is not None
6061

6162
@classmethod
6263
def create_from_auth0(cls, auth0_id: str, auth0_client: Auth0Client) -> Self:
@@ -87,6 +88,20 @@ def get_or_create(
8788
db_session.commit()
8889
return user
8990

91+
def delete(self, session: Session, commit: bool = False) -> "BiocommonsUser":
92+
"""
93+
Soft delete the user and cascade the soft delete to related memberships.
94+
"""
95+
for membership in list(self.platform_memberships or []):
96+
if not membership.is_deleted:
97+
membership.delete(session, commit=False)
98+
for membership in list(self.group_memberships or []):
99+
if not membership.is_deleted:
100+
membership.delete(session, commit=False)
101+
102+
super().delete(session, commit=commit)
103+
return self
104+
90105
def update_from_auth0(self, auth0_id: str, auth0_client: Auth0Client) -> Self:
91106
"""
92107
Fetch user data from Auth0 and update this object with it.
@@ -133,12 +148,12 @@ def add_group_membership(
133148
return membership
134149

135150

136-
class PlatformRoleLink(BaseModel, table=True):
151+
class PlatformRoleLink(SoftDeleteModel, table=True):
137152
platform_id: PlatformEnum = Field(primary_key=True, foreign_key="platform.id", sa_type=DbEnum(PlatformEnum, name="PlatformEnum"))
138153
role_id: str = Field(primary_key=True, foreign_key="auth0role.id")
139154

140155

141-
class Platform(BaseModel, table=True):
156+
class Platform(SoftDeleteModel, table=True):
142157
id: PlatformEnum = Field(primary_key=True, unique=True, sa_type=DbEnum(PlatformEnum, name="PlatformEnum"))
143158
# Human-readable name for the platform
144159
name: str = Field(unique=True)
@@ -168,8 +183,17 @@ def get_approved_by_user_id(cls, user_id: str, session: Session) -> list[Self] |
168183
.where(PlatformMembership.approval_status == ApprovalStatusEnum.APPROVED)
169184
).all()
170185

186+
def delete(self, session: Session, commit: bool = False) -> "Platform":
187+
memberships = list(self.members or [])
188+
for membership in memberships:
189+
if not membership.is_deleted:
190+
membership.delete(session, commit=False)
191+
192+
super().delete(session, commit=commit)
193+
return self
194+
171195

172-
class PlatformMembership(BaseModel, table=True):
196+
class PlatformMembership(SoftDeleteModel, table=True):
173197
__table_args__ = (
174198
UniqueConstraint("platform_id", "user_id", name="platform_user_id_platform_id"),
175199
)
@@ -224,6 +248,26 @@ def get_by_user_id_and_platform_id(cls, user_id: str, platform_id: PlatformEnum,
224248
)
225249
).one_or_none()
226250

251+
def delete(self, session: Session, commit: bool = False) -> "PlatformMembership":
252+
history_entries = session.exec(
253+
select(PlatformMembershipHistory)
254+
.where(
255+
PlatformMembershipHistory.user_id == self.user_id,
256+
PlatformMembershipHistory.platform_id == self.platform_id,
257+
)
258+
).all()
259+
260+
super().delete(session, commit=False)
261+
262+
for history in history_entries:
263+
if not history.is_deleted:
264+
history.delete(session, commit=False)
265+
266+
if commit:
267+
session.commit()
268+
session.expunge(self)
269+
return self
270+
227271
def save_history(self, session: Session) -> "PlatformMembershipHistory":
228272
# Make sure this object is in the session before accessing relationships
229273
if self not in session:
@@ -261,7 +305,7 @@ def get_data(self) -> PlatformMembershipData:
261305

262306

263307

264-
class PlatformMembershipHistory(BaseModel, table=True):
308+
class PlatformMembershipHistory(SoftDeleteModel, table=True):
265309
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
266310
platform_id: PlatformEnum = Field(sa_type=DbEnum(PlatformEnum, name="PlatformEnum"))
267311
user_id: str = Field(foreign_key="biocommons_user.id")
@@ -288,7 +332,7 @@ class PlatformMembershipHistory(BaseModel, table=True):
288332
)
289333

290334

291-
class GroupMembership(BaseModel, table=True):
335+
class GroupMembership(SoftDeleteModel, table=True):
292336
"""
293337
Stores the current approval status for a user/group pairing.
294338
Note: only one row per user/group, the approval history
@@ -351,15 +395,36 @@ def get_by_user_id_and_group_id(cls, user_id: str, group_id: str, session: Sessi
351395
)
352396
).one_or_none()
353397

398+
def delete(self, session: Session, commit: bool = False) -> "GroupMembership":
399+
history_entries = session.exec(
400+
select(GroupMembershipHistory)
401+
.where(
402+
GroupMembershipHistory.user_id == self.user_id,
403+
GroupMembershipHistory.group_id == self.group_id,
404+
)
405+
).all()
406+
407+
super().delete(session, commit=False)
408+
409+
for history in history_entries:
410+
if not history.is_deleted:
411+
history.delete(session, commit=False)
412+
413+
if commit:
414+
session.commit()
415+
session.expunge(self)
416+
return self
417+
354418
@classmethod
355419
def has_group_membership(cls, user_id: str, group_id: str, session: Session) -> bool:
356-
return session.exec(
357-
select(GroupMembership).where(
420+
result = session.exec(
421+
select(GroupMembership.id).where(
358422
GroupMembership.user_id == user_id,
359423
GroupMembership.group_id == group_id,
360424
GroupMembership.approval_status == ApprovalStatusEnum.APPROVED,
361425
)
362-
).exists()
426+
).first()
427+
return result is not None
363428

364429

365430

@@ -422,7 +487,7 @@ def get_data(self) -> GroupMembershipData:
422487

423488

424489

425-
class GroupMembershipHistory(BaseModel, table=True):
490+
class GroupMembershipHistory(SoftDeleteModel, table=True):
426491
"""
427492
Stores the full history of approval decisions for each user
428493
"""
@@ -478,12 +543,12 @@ def get_by_user_id(cls, user_id: str, session: Session) -> list[Self] | None:
478543
).all()
479544

480545

481-
class GroupRoleLink(BaseModel, table=True):
546+
class GroupRoleLink(SoftDeleteModel, table=True):
482547
group_id: str = Field(primary_key=True, foreign_key="biocommonsgroup.group_id")
483548
role_id: str = Field(primary_key=True, foreign_key="auth0role.id")
484549

485550

486-
class Auth0Role(BaseModel, table=True):
551+
class Auth0Role(SoftDeleteModel, table=True):
487552
id: str = Field(primary_key=True, unique=True)
488553
name: str
489554
description: str = Field(default="")
@@ -529,7 +594,7 @@ def get_or_create_by_name(
529594
return role
530595

531596

532-
class BiocommonsGroup(BaseModel, table=True):
597+
class BiocommonsGroup(SoftDeleteModel, table=True):
533598
# Name of the group / role name in Auth0, e.g. biocommons/group/tsi
534599
group_id: str = Field(primary_key=True, unique=True)
535600
# Human-readable name for the group
@@ -549,6 +614,15 @@ class BiocommonsGroup(BaseModel, table=True):
549614
def get_by_id(cls, group_id: str, session: Session) -> Self | None:
550615
return session.get(BiocommonsGroup, group_id)
551616

617+
def delete(self, session: Session, commit: bool = False) -> "BiocommonsGroup":
618+
memberships = list(self.members or [])
619+
for membership in memberships:
620+
if not membership.is_deleted:
621+
membership.delete(session, commit=False)
622+
623+
super().delete(session, commit=commit)
624+
return self
625+
552626
def get_admins(self, auth0_client: Auth0Client) -> set[str]:
553627
"""
554628
Get all admin emails for this group from the Auth0 API, returning a set of emails.

0 commit comments

Comments
 (0)