From 67c729cd26cd10b92fbd293769042bf7e019849e Mon Sep 17 00:00:00 2001
From: Allie Crevier
Date: Wed, 21 Oct 2020 18:30:31 -0700
Subject: [PATCH] add seen tables and test migration

---
 .../versions/bd57477f19a2_add_seen_tables.py |  56 +++++
 securedrop_client/db.py                      |  32 +++
 tests/migrations/test_bd57477f19a2.py        | 233 ++++++++++++++++++
 tests/migrations/utils.py                    | 152 +++++++++++-
 tests/test_alembic.py                        |   2 +-
 5 files changed, 473 insertions(+), 2 deletions(-)
 create mode 100644 alembic/versions/bd57477f19a2_add_seen_tables.py
 create mode 100644 tests/migrations/test_bd57477f19a2.py

diff --git a/alembic/versions/bd57477f19a2_add_seen_tables.py b/alembic/versions/bd57477f19a2_add_seen_tables.py
new file mode 100644
index 0000000000..b6caf34dda
--- /dev/null
+++ b/alembic/versions/bd57477f19a2_add_seen_tables.py
@@ -0,0 +1,56 @@
+"""add seen tables
+
+Revision ID: bd57477f19a2
+Revises: a4bf1f58ce69
+Create Date: 2020-10-20 22:43:46.743035
+
+"""
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = 'bd57477f19a2'
+down_revision = 'a4bf1f58ce69'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('seen_files',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('file_id', sa.Integer(), nullable=False),
+    sa.Column('journalist_id', sa.Integer(), nullable=True),
+    sa.ForeignKeyConstraint(['file_id'], ['files.id'], name=op.f('fk_seen_files_file_id_files')),
+    sa.ForeignKeyConstraint(['journalist_id'], ['users.id'], name=op.f('fk_seen_files_journalist_id_users')),
+    sa.PrimaryKeyConstraint('id', name=op.f('pk_seen_files')),
+    sa.UniqueConstraint('file_id', 'journalist_id', name=op.f('uq_seen_files_file_id'))
+    )
+    op.create_table('seen_messages',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('message_id', sa.Integer(), nullable=False),
+    sa.Column('journalist_id', sa.Integer(), nullable=True),
+    sa.ForeignKeyConstraint(['journalist_id'], ['users.id'], name=op.f('fk_seen_messages_journalist_id_users')),
+    sa.ForeignKeyConstraint(['message_id'], ['messages.id'], name=op.f('fk_seen_messages_message_id_messages')),
+    sa.PrimaryKeyConstraint('id', name=op.f('pk_seen_messages')),
+    sa.UniqueConstraint('message_id', 'journalist_id', name=op.f('uq_seen_messages_message_id'))
+    )
+    op.create_table('seen_replies',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('reply_id', sa.Integer(), nullable=False),
+    sa.Column('journalist_id', sa.Integer(), nullable=True),
+    sa.ForeignKeyConstraint(['journalist_id'], ['users.id'], name=op.f('fk_seen_replies_journalist_id_users')),
+    sa.ForeignKeyConstraint(['reply_id'], ['replies.id'], name=op.f('fk_seen_replies_reply_id_replies')),
+    sa.PrimaryKeyConstraint('id', name=op.f('pk_seen_replies')),
+    sa.UniqueConstraint('reply_id', 'journalist_id', name=op.f('uq_seen_replies_reply_id'))
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('seen_replies')
+    op.drop_table('seen_messages')
+    op.drop_table('seen_files')
+    # ### end Alembic commands ###

diff --git a/securedrop_client/db.py b/securedrop_client/db.py
index 455c7669dd..f4b2868709 100644
--- a/securedrop_client/db.py
+++ b/securedrop_client/db.py
@@ -490,3 +490,35 @@ def initials(self) -> str:
             return self.lastname[0:2].lower()
         else:
             return self.username[0:2].lower()  # username must be at least 3 characters
+
+
+class SeenFile(Base):
+    __tablename__ = "seen_files"
+    __table_args__ = (UniqueConstraint("file_id", "journalist_id"),)
+    id = Column(Integer, primary_key=True)
+    file_id = Column(Integer, ForeignKey("files.id"), nullable=False)
+    journalist_id = Column(Integer, ForeignKey("users.id"), nullable=True)
+    file = relationship("File", backref=backref("seen_files", lazy="dynamic", cascade="all,delete"))
+    journalist = relationship("User", backref=backref("seen_files"))
+
+
+class SeenMessage(Base):
+    __tablename__ = "seen_messages"
+    __table_args__ = (UniqueConstraint("message_id", "journalist_id"),)
+    id = Column(Integer, primary_key=True)
+    message_id = Column(Integer, ForeignKey("messages.id"), nullable=False)
+    journalist_id = Column(Integer, ForeignKey("users.id"), nullable=True)
+    message = relationship(
+        "Message", backref=backref("seen_messages", lazy="dynamic", cascade="all,delete")
+    )
+    journalist = relationship("User", backref=backref("seen_messages"))
+
+
+class SeenReply(Base):
+    __tablename__ = "seen_replies"
+    __table_args__ = (UniqueConstraint("reply_id", "journalist_id"),)
+    id = Column(Integer, primary_key=True)
+    reply_id = Column(Integer, ForeignKey("replies.id"), nullable=False)
+    journalist_id = Column(Integer, ForeignKey("users.id"), nullable=True)
+    reply = relationship("Reply", backref=backref("seen_replies", cascade="all,delete"))
+    journalist = relationship("User", backref=backref("seen_replies"))

diff --git a/tests/migrations/test_bd57477f19a2.py b/tests/migrations/test_bd57477f19a2.py
new file mode 100644
index 0000000000..cd83e80c65
--- /dev/null
+++ b/tests/migrations/test_bd57477f19a2.py
@@ -0,0 +1,233 @@
+# -*- coding: utf-8 -*-
+
+import os
+import random
+import subprocess
+
+import pytest
+from sqlalchemy import text
+from sqlalchemy.exc import IntegrityError
+
+from securedrop_client import db
+from securedrop_client.db import Reply, User
+
+from .utils import (
+    add_file,
+    add_message,
+    add_reply,
+    add_source,
+    add_user,
+    mark_file_as_seen,
+    mark_message_as_seen,
+    mark_reply_as_seen,
+)
+
+
+class UpgradeTester:
+    """
+    Verify that upgrading to the target migration results in the creation of the seen tables.
+    """
+
+    NUM_USERS = 20
+    NUM_SOURCES = 20
+    NUM_REPLIES = 40
+
+    def __init__(self, homedir):
+        subprocess.check_call(["sqlite3", os.path.join(homedir, "svs.sqlite"), ".databases"])
+        self.session = db.make_session_maker(homedir)()
+
+    def load_data(self):
+        for source_id in range(1, self.NUM_SOURCES + 1):
+            add_source(self.session)
+
+            # Add zero to a few messages from each source; some messages are set to downloaded
+            for _ in range(random.randint(0, 2)):
+                add_message(self.session, source_id)
+
+            # Add zero to a few files from each source; some files are set to downloaded
+            for _ in range(random.randint(0, 2)):
+                add_file(self.session, source_id)
+
+        self.session.commit()
+
+        for i in range(self.NUM_USERS):
+            if i == 0:
+                # As of this migration, the server tells the client that the associated journalist
+                # of a reply has been deleted by returning "deleted" as the uuid of the associated
+                # journalist. This gets stored as the journalist_id in the replies table.
+                #
+                # Make sure to test this case as well.
+                add_user(self.session, "deleted")
+                source_id = random.randint(1, self.NUM_SOURCES)
+                user = self.session.query(User).filter_by(uuid="deleted").one()
+                add_reply(self.session, user.id, source_id)
+            else:
+                add_user(self.session)
+
+        self.session.commit()
+
+        # Add replies from randomly-selected journalists to randomly-selected sources
+        for _ in range(self.NUM_REPLIES):
+            journalist_id = random.randint(1, self.NUM_USERS)
+            source_id = random.randint(1, self.NUM_SOURCES)
+            add_reply(self.session, journalist_id, source_id)
+
+        self.session.commit()
+
+    def check_upgrade(self):
+        """
+        Make sure the seen tables exist and work as expected.
+        """
+        replies = self.session.query(Reply).all()
+        assert len(replies)
+
+        for reply in replies:
+            # Will fail if the User does not exist
+            self.session.query(User).filter_by(id=reply.journalist_id).one()
+
+        sql = "SELECT * FROM files"
+        files = self.session.execute(text(sql)).fetchall()
+
+        sql = "SELECT * FROM messages"
+        messages = self.session.execute(text(sql)).fetchall()
+
+        sql = "SELECT * FROM replies"
+        replies = self.session.execute(text(sql)).fetchall()
+
+        # Now that the seen tables exist, you should be able to mark some files, messages,
+        # and replies as seen
+        for file in files:
+            if random.choice([0, 1]):
+                selected_journo_id = random.randint(1, self.NUM_USERS)
+                mark_file_as_seen(self.session, file.id, selected_journo_id)
+        for message in messages:
+            if random.choice([0, 1]):
+                selected_journo_id = random.randint(1, self.NUM_USERS)
+                mark_message_as_seen(self.session, message.id, selected_journo_id)
+        for reply in replies:
+            if random.choice([0, 1]):
+                selected_journo_id = random.randint(1, self.NUM_USERS)
+                mark_reply_as_seen(self.session, reply.id, selected_journo_id)
+
+        # Check the unique constraint on (reply_id, journalist_id)
+        params = {"reply_id": 100, "journalist_id": 100}
+        sql = """
+            INSERT INTO seen_replies (reply_id, journalist_id)
+            VALUES (:reply_id, :journalist_id);
+        """
+        self.session.execute(text(sql), params)
+        with pytest.raises(IntegrityError):
+            self.session.execute(text(sql), params)
+
+        # Check the unique constraint on (message_id, journalist_id)
+        params = {"message_id": 100, "journalist_id": 100}
+        sql = """
+            INSERT INTO seen_messages (message_id, journalist_id)
+            VALUES (:message_id, :journalist_id);
+        """
+        self.session.execute(text(sql), params)
+        with pytest.raises(IntegrityError):
+            self.session.execute(text(sql), params)
+
+        # Check the unique constraint on (file_id, journalist_id)
+        params = {"file_id": 101, "journalist_id": 100}
+        sql = """
+            INSERT INTO seen_files (file_id, journalist_id)
+            VALUES (:file_id, :journalist_id);
+        """
+        self.session.execute(text(sql), params)
+        with pytest.raises(IntegrityError):
+            self.session.execute(text(sql), params)
+
+
+class DowngradeTester:
+    """
+    Verify that downgrading from the target migration drops the seen tables while leaving the
+    rest of the data in place; nothing needs to be added back into the db because the migration
+    is backwards compatible.
+    """
+
+    NUM_USERS = 20
+    NUM_SOURCES = 20
+    NUM_REPLIES = 40
+
+    def __init__(self, homedir):
+        subprocess.check_call(["sqlite3", os.path.join(homedir, "svs.sqlite"), ".databases"])
+        self.session = db.make_session_maker(homedir)()
+
+    def load_data(self):
+        for source_id in range(1, self.NUM_SOURCES + 1):
+            add_source(self.session)
+
+            # Add zero to a few messages from each source; some messages are set to downloaded
+            for _ in range(random.randint(0, 3)):
+                add_message(self.session, source_id)
+
+            # Add zero to a few files from each source; some files are set to downloaded
+            for _ in range(random.randint(0, 3)):
+                add_file(self.session, source_id)
+
+        self.session.commit()
+
+        for i in range(self.NUM_USERS):
+            if i == 0:
+                # As of this migration, the server tells the client that the associated journalist
+                # of a reply has been deleted by returning "deleted" as the uuid of the associated
+                # journalist. This gets stored as the journalist_id in the replies table.
+                #
+                # Make sure to test this case as well.
+                add_user(self.session, "deleted")
+                source_id = random.randint(1, self.NUM_SOURCES)
+                user = self.session.query(User).filter_by(uuid="deleted").one()
+                add_reply(self.session, user.id, source_id)
+            else:
+                add_user(self.session)
+
+        self.session.commit()
+
+        # Add replies from randomly-selected journalists to randomly-selected sources
+        for _ in range(self.NUM_REPLIES):
+            journalist_id = random.randint(1, self.NUM_USERS)
+            source_id = random.randint(1, self.NUM_SOURCES)
+            add_reply(self.session, journalist_id, source_id)
+
+        self.session.commit()
+
+        # Mark some files, messages, and replies as seen
+        sql = "SELECT * FROM files"
+        files = self.session.execute(text(sql)).fetchall()
+        for file in files:
+            if random.choice([0, 1]):
+                selected_journo_id = random.randint(1, self.NUM_USERS)
+                mark_file_as_seen(self.session, file.id, selected_journo_id)
+
+        sql = "SELECT * FROM messages"
+        messages = self.session.execute(text(sql)).fetchall()
+        for message in messages:
+            if random.choice([0, 1]):
+                selected_journo_id = random.randint(1, self.NUM_USERS)
+                mark_message_as_seen(self.session, message.id, selected_journo_id)
+
+        sql = "SELECT * FROM replies"
+        replies = self.session.execute(text(sql)).fetchall()
+        for reply in replies:
+            if random.choice([0, 1]):
+                selected_journo_id = random.randint(1, self.NUM_USERS)
+                mark_reply_as_seen(self.session, reply.id, selected_journo_id)
+
+        self.session.commit()
+
+    def check_downgrade(self):
+        """
+        Check that the seen tables no longer exist.
+        """
+        sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name;"
+        for table_name in ("seen_files", "seen_messages", "seen_replies"):
+            params = {"table_name": table_name}
+            assert not self.session.execute(text(sql), params).fetchall()

diff --git a/tests/migrations/utils.py b/tests/migrations/utils.py
index 3f37ed0b1a..4d2f6628ce 100644
--- a/tests/migrations/utils.py
+++ b/tests/migrations/utils.py
@@ -6,6 +6,7 @@
 from uuid import uuid4
 
 from sqlalchemy import text
+from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm.session import Session
 
 from securedrop_client.db import DownloadError, Source
@@ -92,6 +93,111 @@ def add_user(session: Session, uuid: Optional[str] = None) -> None:
     session.execute(text(sql), params)
 
 
+def add_file(session: Session, source_id: int) -> None:
+    is_downloaded = random_bool()
+    is_decrypted = random_bool() if is_downloaded else None
+
+    source = session.query(Source).filter_by(id=source_id).one()
+    file_counter = len(source.collection) + 1
+
+    params = {
+        "uuid": str(uuid4()),
+        "source_id": source_id,
+        "filename": random_chars(50) + "-doc.gz.gpg",
+        "file_counter": file_counter,
+        "size": random.randint(0, 1024 * 1024 * 500),
+        "download_url": random_chars(50),
+        "is_downloaded": is_downloaded,
+        "is_decrypted": is_decrypted,
+        "is_read": random.choice([True, False]),
+        "last_updated": random_datetime(),
+    }
+    sql = """
+        INSERT INTO files
+        (
+            uuid,
+            source_id,
+            filename,
+            file_counter,
+            size,
+            download_url,
+            is_downloaded,
+            is_decrypted,
+            is_read,
+            last_updated
+        )
+        VALUES
+        (
+            :uuid,
+            :source_id,
+            :filename,
+            :file_counter,
+            :size,
+            :download_url,
+            :is_downloaded,
+            :is_decrypted,
+            :is_read,
+            :last_updated
+        )
+    """
+    session.execute(text(sql), params)
+
+
+def add_message(session: Session, source_id: int) -> None:
+    is_downloaded = random_bool()
+    is_decrypted = random_bool() if is_downloaded else None
+
+    content = random_chars(1000) if is_downloaded else None
+
+    source = session.query(Source).filter_by(id=source_id).one()
+    file_counter = len(source.collection) + 1
+
+    params = {
+        "uuid": str(uuid4()),
+        "source_id": source_id,
+        "filename": random_chars(50) + "-doc.gz.gpg",
+        "file_counter": file_counter,
+        "size": random.randint(0, 1024 * 1024 * 500),
+        "content": content,
+        "download_url": random_chars(50),
+        "is_downloaded": is_downloaded,
+        "is_decrypted": is_decrypted,
+        "is_read": random.choice([True, False]),
+        "last_updated": random_datetime(),
+    }
+    sql = """
+        INSERT INTO messages
+        (
+            uuid,
+            source_id,
+            filename,
+            file_counter,
+            size,
+            content,
+            download_url,
+            is_downloaded,
+            is_decrypted,
+            is_read,
+            last_updated
+        )
+        VALUES
+        (
+            :uuid,
+            :source_id,
+            :filename,
+            :file_counter,
+            :size,
+            :content,
+            :download_url,
+            :is_downloaded,
+            :is_decrypted,
+            :is_read,
+            :last_updated
+        )
+    """
+    session.execute(text(sql), params)
+
+
 def add_reply(session: Session, journalist_id: int, source_id: int) -> None:
     is_downloaded = random_bool() if random_bool() else None
     is_decrypted = random_bool() if is_downloaded else None
@@ -102,7 +208,6 @@ def add_reply(session: Session, journalist_id: int, source_id: int) -> None:
     content = random_chars(1000) if is_downloaded else None
 
     source = session.query(Source).filter_by(id=source_id).one()
-
     file_counter = len(source.collection) + 1
 
     params = {
@@ -149,3 +254,48 @@ def add_reply(session: Session, journalist_id: int, source_id: int) -> None:
     )
     """
     session.execute(text(sql), params)
+
+
+def mark_file_as_seen(session: Session, file_id: int, journalist_id: int) -> None:
+    params = {
+        "file_id": file_id,
+        "journalist_id": journalist_id,
+    }
+    sql = """
+        INSERT INTO seen_files (file_id, journalist_id)
+        VALUES (:file_id, :journalist_id)
+    """
+    try:
+        session.execute(text(sql), params)
+    except IntegrityError:
+        # Unique constraint violation: the file is already marked as seen by this user
+        pass
+
+
+def mark_message_as_seen(session: Session, message_id: int, journalist_id: int) -> None:
+    params = {
+        "message_id": message_id,
+        "journalist_id": journalist_id,
+    }
+    sql = """
+        INSERT INTO seen_messages (message_id, journalist_id)
+        VALUES (:message_id, :journalist_id)
+    """
+    try:
+        session.execute(text(sql), params)
+    except IntegrityError:
+        # Unique constraint violation: the message is already marked as seen by this user
+        pass
+
+
+def mark_reply_as_seen(session: Session, reply_id: int, journalist_id: int) -> None:
+    params = {
+        "reply_id": reply_id,
+        "journalist_id": journalist_id,
+    }
+    sql = """
+        INSERT INTO seen_replies (reply_id, journalist_id)
+        VALUES (:reply_id, :journalist_id)
+    """
+    try:
+        session.execute(text(sql), params)
+    except IntegrityError:
+        # Unique constraint violation: the reply is already marked as seen by this user
+        pass

diff --git a/tests/test_alembic.py b/tests/test_alembic.py
index 2a406dc0f5..d3a8cdc4f6 100644
--- a/tests/test_alembic.py
+++ b/tests/test_alembic.py
@@ -20,7 +20,7 @@
     x.split(".")[0].split("_")[0] for x in os.listdir(MIGRATION_PATH) if x.endswith(".py")
 ]
 
-DATA_MIGRATIONS = ["a4bf1f58ce69"]
+DATA_MIGRATIONS = ["a4bf1f58ce69", "bd57477f19a2"]
 
 WHITESPACE_REGEX = re.compile(r"\s+")
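
Reviewer note: below is a minimal sketch of how the new ORM models from db.py could be
exercised once this migration is applied. It is illustrative only and not part of the patch:
the session setup, the pre-existing File and User rows, and the helper names
mark_file_seen_orm / file_seen_by_anyone are assumptions for the example.

# Sketch only: assumes a session created via db.make_session_maker(homedir)() against a
# database already upgraded to bd57477f19a2, with File and User rows to reference.
from sqlalchemy.exc import IntegrityError

from securedrop_client.db import File, SeenFile, User


def mark_file_seen_orm(session, file: File, journalist: User) -> None:
    # ORM counterpart of the raw-SQL mark_file_as_seen() test helper above.
    try:
        session.add(SeenFile(file_id=file.id, journalist_id=journalist.id))
        session.commit()
    except IntegrityError:
        # The (file_id, journalist_id) unique constraint fired: this journalist has
        # already seen the file, so treat the call as a no-op.
        session.rollback()


def file_seen_by_anyone(session, file: File) -> bool:
    # True if at least one journalist has a seen_files row for this file.
    return session.query(SeenFile).filter_by(file_id=file.id).count() > 0

SeenMessage and SeenReply would be used the same way against their respective unique
constraints.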