Smooth more rough edges db models (#106)
* build: Bump alembic, v1.11 => v1.13

* feat: Move study id fkeys into own cols

* feat: Add db migration for updated study fkeys

* fix: Update dupe logic/names for new db model

* fix: Update data extract api for db model changes

* fix: Update data schemas for db model change
bdewilde authored Apr 29, 2024
1 parent b6f6b60 commit ed7f74d
Showing 7 changed files with 159 additions and 48 deletions.
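The gist of the change, per the diffs below: Dedupe and DataExtraction no longer reuse the study's primary key as their own; each gets a surrogate id plus a separate, unique study_id foreign key. A minimal before/after sketch in plain SQLAlchemy 2.0 declarative style (the app itself uses its db.Model base and the mapcol/M aliases, so the names here are illustrative only):

import sqlalchemy as sa
import sqlalchemy.orm as sa_orm


class Base(sa_orm.DeclarativeBase):
    pass


class Study(Base):
    __tablename__ = "studies"
    id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.BigInteger, primary_key=True)


class DataExtraction(Base):
    __tablename__ = "data_extractions"
    # before: the primary key doubled as the foreign key to studies.id
    #   id = mapped_column(sa.BigInteger, sa.ForeignKey("studies.id"), primary_key=True)
    # after: a plain surrogate primary key ...
    id: sa_orm.Mapped[int] = sa_orm.mapped_column(
        sa.BigInteger, primary_key=True, autoincrement=True
    )
    # ... plus its own unique, indexed foreign key back to the study
    study_id: sa_orm.Mapped[int] = sa_orm.mapped_column(
        sa.BigInteger,
        sa.ForeignKey("studies.id", ondelete="CASCADE"),
        index=True,
        unique=True,
    )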
15 changes: 10 additions & 5 deletions colandr/apis/resources/data_extractions.py
@@ -49,7 +49,9 @@ def get(self, id):
"""get data extraction record for a single study by id"""
current_user = jwtext.get_current_user()
# check current user authorization
extracted_data = db.session.get(DataExtraction, id)
extracted_data = db.session.execute(
sa.select(DataExtraction).filter_by(study_id=id)
).scalar_one_or_none()
if not extracted_data:
return not_found_error(f"<DataExtraction(study_id={id})> not found")
# TODO: figure out if this is "better" approach
@@ -103,7 +105,9 @@ def delete(self, id, labels):
"""delete data extraction record for a single study by id"""
current_user = jwtext.get_current_user()
# check current user authorization
extracted_data = db.session.get(DataExtraction, id)
extracted_data = db.session.execute(
sa.select(DataExtraction).filter_by(study_id=id)
).scalar_one_or_none()
if not extracted_data:
return not_found_error(f"<DataExtraction(study_id={id})> not found")
if (
@@ -154,11 +158,12 @@ def put(self, args, id):
"""modify data extraction record for a single study by id"""
current_user = jwtext.get_current_user()
# check current user authorization
extracted_data = db.session.get(DataExtraction, id)
assert extracted_data is not None # type guard
review_id = extracted_data.review_id
extracted_data = db.session.execute(
sa.select(DataExtraction).filter_by(study_id=id)
).scalar_one_or_none()
if not extracted_data:
return not_found_error(f"<DataExtraction(study_id={id})> not found")
review_id = extracted_data.review_id
if (
current_user.review_user_assoc.filter_by(review_id=review_id).one_or_none()
is None
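Because id is now a surrogate key, db.session.get(DataExtraction, id) (a primary-key lookup) no longer finds the record for a given study; the handlers above switch to selecting on the new study_id column instead. A minimal sketch of that pattern, assuming a SQLAlchemy session and the updated model (the helper name is illustrative):

import sqlalchemy as sa

from colandr.models import DataExtraction


def get_extraction_for_study(session, study_id: int):
    # at most one row can match, thanks to the unique constraint on study_id;
    # returns None when the study has no data extraction record yet
    return session.execute(
        sa.select(DataExtraction).filter_by(study_id=study_id)
    ).scalar_one_or_none()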
2 changes: 2 additions & 0 deletions colandr/apis/schemas.py
@@ -118,6 +118,7 @@ class ImportSchema(Schema):
class DedupeSchema(Schema):
id = fields.Int(dump_only=True)
created_at = fields.DateTime(dump_only=True, format="iso")
study_id = fields.Int(required=True, validate=Range(min=1, max=constants.MAX_INT))
review_id = fields.Int(required=True, validate=Range(min=1, max=constants.MAX_INT))
duplicate_of = fields.Int(
load_default=None, validate=Range(min=1, max=constants.MAX_BIGINT)
@@ -211,6 +212,7 @@ class DataExtractionSchema(Schema):
id = fields.Int(dump_only=True)
created_at = fields.DateTime(dump_only=True, format="iso")
updated_at = fields.DateTime(dump_only=True, format="iso")
study_id = fields.Int(required=True, validate=Range(min=1, max=constants.MAX_INT))
review_id = fields.Int(required=True, validate=Range(min=1, max=constants.MAX_INT))
extracted_items = fields.Nested(ExtractedItem, many=True)

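With study_id now a required field on both DedupeSchema and DataExtractionSchema, payloads validated through them must include it. A small marshmallow sketch of the effect (trimmed to the relevant fields; MAX_INT stands in for constants.MAX_INT):

from marshmallow import Schema, fields
from marshmallow.validate import Range

MAX_INT = 2_147_483_647  # stand-in for constants.MAX_INT


class DedupeSchema(Schema):
    id = fields.Int(dump_only=True)
    study_id = fields.Int(required=True, validate=Range(min=1, max=MAX_INT))
    review_id = fields.Int(required=True, validate=Range(min=1, max=MAX_INT))


# a payload without study_id now fails validation
print(DedupeSchema().validate({"review_id": 1}))
# => {'study_id': ['Missing data for required field.']}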
51 changes: 33 additions & 18 deletions colandr/models.py
@@ -289,7 +289,11 @@ class DataSource(db.Model):
__tablename__ = "data_sources"
__table_args__ = (
db.UniqueConstraint(
"source_type", "source_name", name="source_type_source_name_uc"
"source_type",
"source_name",
"source_url",
name="uq_source_type_name_url",
postgresql_nulls_not_distinct=True,
),
)

@@ -580,13 +584,17 @@ class Dedupe(db.Model):
__tablename__ = "dedupes"

# columns
id: M[int] = mapcol(
sa.BigInteger, sa.ForeignKey("studies.id", ondelete="CASCADE"), primary_key=True
)
id: M[int] = mapcol(sa.BigInteger, primary_key=True, autoincrement=True)
created_at: M[datetime.datetime] = mapcol(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
)
study_id: M[int] = mapcol(
sa.BigInteger,
sa.ForeignKey("studies.id", ondelete="CASCADE"),
index=True,
unique=True,
)
review_id: M[int] = mapcol(
sa.Integer, sa.ForeignKey("reviews.id", ondelete="CASCADE"), index=True
)
@@ -597,29 +605,27 @@ class Dedupe(db.Model):

# relationships
study: M["Study"] = sa_orm.relationship(
"Study", foreign_keys=[id], back_populates="dedupe", lazy="select"
"Study", foreign_keys=[study_id], back_populates="dedupe", lazy="select"
)
review: M["Review"] = sa_orm.relationship(
"Review", foreign_keys=[review_id], back_populates="dedupes", lazy="select"
)

def __init__(self, id_, review_id, duplicate_of, duplicate_score):
self.id = id_
def __init__(self, study_id, review_id, duplicate_of, duplicate_score):
self.study_id = study_id
self.review_id = review_id
self.duplicate_of = duplicate_of
self.duplicate_score = duplicate_score

def __repr__(self):
return f"<Dedupe(study_id={self.id})>"
return f"<Dedupe(study_id={self.study_id})>"


class DataExtraction(db.Model):
__tablename__ = "data_extractions"

# columns
id: M[int] = mapcol(
sa.BigInteger, sa.ForeignKey("studies.id", ondelete="CASCADE"), primary_key=True
)
id: M[int] = mapcol(sa.BigInteger, primary_key=True, autoincrement=True)
created_at: M[datetime.datetime] = mapcol(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
@@ -630,6 +636,12 @@ class DataExtraction(db.Model):
server_default=sa.func.now(),
server_onupdate=sa.FetchedValue(),
)
study_id: M[int] = mapcol(
sa.BigInteger,
sa.ForeignKey("studies.id", ondelete="CASCADE"),
index=True,
unique=True,
)
review_id: M[int] = mapcol(
sa.Integer,
sa.ForeignKey("reviews.id", ondelete="CASCADE"),
@@ -641,7 +653,10 @@ class DataExtraction(db.Model):

# relationships
study: M["Study"] = sa_orm.relationship(
"Study", foreign_keys=[id], back_populates="data_extraction", lazy="select"
"Study",
foreign_keys=[study_id],
back_populates="data_extraction",
lazy="select",
)
review: M["Review"] = sa_orm.relationship(
"Review",
Expand All @@ -650,13 +665,13 @@ class DataExtraction(db.Model):
lazy="select",
)

def __init__(self, id_, review_id, extracted_items=None):
self.id = id_
def __init__(self, study_id, review_id, extracted_items=None):
self.study_id = study_id
self.review_id = review_id
self.extracted_items = extracted_items

def __repr__(self):
return f"<DataExtraction(study_id={self.id})>"
return f"<DataExtraction(study_id={self.study_id})>"


# EVENTS
@@ -773,18 +788,18 @@ def update_study_status(mapper, connection, target):
elif stage == "fulltext":
# we may have to insert or delete a corresponding data extraction record
data_extraction = connection.execute(
sa.select(DataExtraction).where(DataExtraction.id == study_id)
sa.select(DataExtraction).where(DataExtraction.study_id == study_id)
).first()
# data_extraction_inserted_or_deleted = False
if status == "included" and data_extraction is None:
connection.execute(
sa.insert(DataExtraction).values(id=study_id, review_id=review_id)
sa.insert(DataExtraction).values(study_id=study_id, review_id=review_id)
)
LOGGER.info("inserted <DataExtraction(study_id=%s)>", study_id)
# data_extraction_inserted_or_deleted = True
elif status != "included" and data_extraction is not None:
connection.execute(
sa.delete(DataExtraction).where(DataExtraction.id == study_id)
sa.delete(DataExtraction).where(DataExtraction.study_id == study_id)
)
LOGGER.info("deleted <DataExtraction(study_id=%s)>", study_id)
# data_extraction_inserted_or_deleted = True
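The model constructors now take study_id in place of the old id_ argument, and the study relationships and __repr__s follow the new column. A brief usage sketch against the updated models, assuming the app's db session and an existing study/review pair (the ids are illustrative):

from colandr.models import DataExtraction, Dedupe

# one data-extraction record per study, keyed by the new study_id column
extraction = DataExtraction(study_id=42, review_id=7, extracted_items=None)
db.session.add(extraction)
db.session.commit()

assert extraction.study_id == 42   # the link to the study
assert extraction.id is not None   # fresh surrogate primary key
print(repr(extraction))            # <DataExtraction(study_id=42)>

# dedupe rows are built the same way
dupe = Dedupe(study_id=43, review_id=7, duplicate_of=42, duplicate_score=0.9)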
44 changes: 21 additions & 23 deletions colandr/tasks.py
@@ -128,27 +128,25 @@ def deduplicate_citations(review_id: int):

# get *all* citation ids for this review, as well as included/excluded
stmt = sa.select(models.Study.id).where(models.Study.review_id == review_id)
all_cids = set(db.session.execute(stmt).scalars().all())
all_sids = set(db.session.execute(stmt).scalars().all())
stmt = (
sa.select(models.Study.id)
.where(models.Study.review_id == review_id)
# .where(models.Study.citation_status.in_(["included", "excluded"]))
.where(models.Study.citation_status == sa.any_(["included", "excluded"]))
)
incl_excl_cids = set(db.session.execute(stmt).scalars().all())
incl_excl_sids = set(db.session.execute(stmt).scalars().all())

duplicate_cids = set()
duplicate_sids = set()
studies_to_update = []
dedupes_to_insert = []
for cids, scores in clustered_dupes:
int_cids = [int(cid) for cid in cids] # convert from numpy.int64
cid_scores = {cid: float(score) for cid, score in zip(int_cids, scores)}
for sids, scores in clustered_dupes:
int_sids = [int(sid) for sid in sids] # convert from numpy.int64
sid_scores = {sid: float(score) for sid, score in zip(int_sids, scores)}
# already an in/excluded citation in this dupe cluster?
# take the first one to be "canonical"
if any(cid in incl_excl_cids for cid in int_cids):
canonical_citation_id = sorted(set(int_cids).intersection(incl_excl_cids))[
0
]
if any(sid in incl_excl_sids for sid in int_sids):
canonical_study_id = sorted(set(int_sids).intersection(incl_excl_sids))[0]
# otherwise, take the "most complete" citation in the cluster as "canonical"
else:
stmt = (
@@ -173,30 +171,30 @@ def deduplicate_citations(review_id: int):
).label("n_null_cols"),
)
.where(models.Study.review_id == review_id)
# .where(Citation.id.in_(int_cids))
.where(models.Study.id == sa.any_(int_cids))
# .where(models.Study.id.in_(int_sids))
.where(models.Study.id == sa.any_(int_sids))
.order_by(sa.text("n_null_cols ASC"))
.limit(1)
)
result = db.session.execute(stmt).first()
assert result is not None
canonical_citation_id = result.id
canonical_study_id = result.id

for cid, score in cid_scores.items():
if cid != canonical_citation_id:
duplicate_cids.add(cid)
studies_to_update.append({"id": cid, "dedupe_status": "duplicate"})
for sid, score in sid_scores.items():
if sid != canonical_study_id:
duplicate_sids.add(sid)
studies_to_update.append({"id": sid, "dedupe_status": "duplicate"})
dedupes_to_insert.append(
{
"id": cid,
"study_id": sid,
"review_id": review_id,
"duplicate_of": canonical_citation_id,
"duplicate_of": canonical_study_id,
"duplicate_score": score,
}
)
non_duplicate_cids = all_cids - duplicate_cids
non_duplicate_sids = all_sids - duplicate_sids
studies_to_update.extend(
{"id": cid, "dedupe_status": "not_duplicate"} for cid in non_duplicate_cids
{"id": sid, "dedupe_status": "not_duplicate"} for sid in non_duplicate_sids
)

db.session.execute(sa.update(models.Study), studies_to_update)
@@ -205,8 +203,8 @@
LOGGER.info(
"<Review(id=%s)>: found %s duplicate and %s non-duplicate citations",
review_id,
len(duplicate_cids),
len(non_duplicate_cids),
len(duplicate_sids),
len(non_duplicate_sids),
)

lock.release()
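Two details worth noting in the dedupe task: citation_status == sa.any_([...]) is the PostgreSQL "= ANY (ARRAY[...])" spelling of the .in_() filter left commented out above, and the per-cluster dicts are shaped for executemany-style bulk statements (the Study update is shown above; the matching Dedupe insert is presumably executed the same way outside this hunk). A rough sketch of both, assuming the app's models and db objects:

import sqlalchemy as sa

# equivalent filters: "status = ANY (ARRAY[...])" vs "status IN (...)"
stmt_any = sa.select(models.Study.id).where(
    models.Study.citation_status == sa.any_(["included", "excluded"])
)
stmt_in = sa.select(models.Study.id).where(
    models.Study.citation_status.in_(["included", "excluded"])
)

# executemany-style bulk update / insert from lists of dicts
studies_to_update = [{"id": 3, "dedupe_status": "duplicate"}]
dedupes_to_insert = [
    {"study_id": 3, "review_id": 7, "duplicate_of": 2, "duplicate_score": 0.97}
]
db.session.execute(sa.update(models.Study), studies_to_update)
db.session.execute(sa.insert(models.Dedupe), dedupes_to_insert)
db.session.commit()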
91 changes: 91 additions & 0 deletions migrations/versions/ff8fa67b9273_stop_sharing_study_pkeys.py
@@ -0,0 +1,91 @@
"""stop sharing study pkeys
Revision ID: ff8fa67b9273
Revises: 6899968b51c0
Create Date: 2024-04-27 23:44:58.876593
"""

import sqlalchemy as sa
from alembic import op


# revision identifiers, used by Alembic.
revision = "ff8fa67b9273"
down_revision = "6899968b51c0"
branch_labels = None
depends_on = None


def upgrade():
with op.batch_alter_table(
"data_extractions", schema=None, recreate="always"
) as batch_op:
batch_op.add_column(
sa.Column("study_id", sa.BigInteger(), nullable=True),
insert_before="review_id",
)
op.execute("UPDATE data_extractions SET study_id = id")
with op.batch_alter_table("data_extractions", schema=None) as batch_op:
batch_op.alter_column("study_id", existing_type=sa.BigInteger(), nullable=False)
batch_op.create_index(
batch_op.f("ix_data_extractions_study_id"), ["study_id"], unique=True
)
batch_op.drop_constraint("data_extractions_id_fkey", type_="foreignkey")
batch_op.create_foreign_key(
"data_extractions_study_id_fkey",
"studies",
["study_id"],
["id"],
ondelete="CASCADE",
)

with op.batch_alter_table("dedupes", schema=None, recreate="always") as batch_op:
batch_op.add_column(
sa.Column("study_id", sa.BigInteger(), nullable=True),
insert_before="review_id",
)
op.execute("UPDATE dedupes SET study_id = id")
with op.batch_alter_table("dedupes", schema=None) as batch_op:
batch_op.alter_column("study_id", existing_type=sa.BigInteger(), nullable=False)
batch_op.create_index(
batch_op.f("ix_dedupes_study_id"), ["study_id"], unique=True
)
batch_op.drop_constraint("dedupes_id_fkey", type_="foreignkey")
batch_op.create_foreign_key(
"dedupes_study_id_fkey", "studies", ["study_id"], ["id"], ondelete="CASCADE"
)

with op.batch_alter_table("data_sources", schema=None) as batch_op:
batch_op.drop_constraint("source_type_source_name_uc", type_="unique")
batch_op.create_unique_constraint(
"uq_source_type_name_url",
["source_type", "source_name", "source_url"],
postgresql_nulls_not_distinct=True,
)


def downgrade():
with op.batch_alter_table("data_sources", schema=None) as batch_op:
batch_op.drop_constraint("uq_source_type_name_url", type_="unique")
batch_op.create_unique_constraint(
"source_type_source_name_uc", ["source_type", "source_name"]
)

op.execute("UPDATE dedupes SET id = study_id")
with op.batch_alter_table("dedupes", schema=None) as batch_op:
batch_op.drop_constraint("dedupes_study_id_fkey", type_="foreignkey")
batch_op.create_foreign_key(
"dedupes_id_fkey", "studies", ["id"], ["id"], ondelete="CASCADE"
)
batch_op.drop_index(batch_op.f("ix_dedupes_study_id"))
batch_op.drop_column("study_id")

op.execute("UPDATE data_extractions SET id = study_id")
with op.batch_alter_table("data_extractions", schema=None) as batch_op:
batch_op.drop_constraint("data_extractions_study_id_fkey", type_="foreignkey")
batch_op.create_foreign_key(
"data_extractions_id_fkey", "studies", ["id"], ["id"], ondelete="CASCADE"
)
batch_op.drop_index(batch_op.f("ix_data_extractions_study_id"))
batch_op.drop_column("study_id")
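The upgrade backfills study_id from the old shared id before tightening it to NOT NULL, adding the unique index, and swapping the foreign keys; the downgrade copies the values back. Applying it is an ordinary Alembic upgrade; a hedged sketch via Alembic's programmatic API (the alembic.ini path is an assumption, and a Flask-Migrate setup would typically run "flask db upgrade" instead):

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed config location

# apply this revision (adds/backfills study_id, unique index, new fkeys)
command.upgrade(cfg, "ff8fa67b9273")

# revert to the prior revision if needed
command.downgrade(cfg, "6899968b51c0")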
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
]
dependencies = [
"alembic~=1.11.0",
"alembic~=1.13.0",
"arrow~=1.3.0",
"bibtexparser~=1.4.0",
"celery~=5.3.0",
2 changes: 1 addition & 1 deletion requirements/prod.txt
@@ -1,6 +1,6 @@
# TODO: keep aligned with pyproject.toml
# TODO: get rid of this once pip allows for dep-only installs from pyproject.toml
alembic~=1.11.0
alembic~=1.13.0
arrow~=1.3.0
bibtexparser~=1.4.0
celery~=5.3.0
