diff --git a/alembic/versions/a4bf1f58ce69_fix_journalist_association_in_replies.py b/alembic/versions/a4bf1f58ce69_fix_journalist_association_in_replies.py index ff2a2260ad..782b2f0ef6 100644 --- a/alembic/versions/a4bf1f58ce69_fix_journalist_association_in_replies.py +++ b/alembic/versions/a4bf1f58ce69_fix_journalist_association_in_replies.py @@ -18,7 +18,7 @@ def upgrade(): """ Fix reply association with journalist by updating journalist uuid to journalist id in the - journalist_id column. + journalist_id column for the replies and draftreplies tables. """ conn = op.get_bind() cursor = conn.execute(""" @@ -28,25 +28,47 @@ def upgrade(): """) replies_with_incorrect_associations = cursor.fetchall() - if not replies_with_incorrect_associations: - return - - conn.execute(""" - UPDATE replies - SET journalist_id= - ( - SELECT users.id - FROM users - WHERE journalist_id=users.uuid - ) - WHERE exists - ( - SELECT users.id - FROM users - WHERE journalist_id=users.uuid - ); + if replies_with_incorrect_associations: + conn.execute(""" + UPDATE replies + SET journalist_id= + ( + SELECT users.id + FROM users + WHERE journalist_id=users.uuid + ) + WHERE exists + ( + SELECT users.id + FROM users + WHERE journalist_id=users.uuid + ); + """) + + cursor = conn.execute(""" + SELECT journalist_id + FROM draftreplies, users + WHERE journalist_id=users.uuid; """) + draftreplies_with_incorrect_associations = cursor.fetchall() + if draftreplies_with_incorrect_associations: + conn.execute(""" + UPDATE draftreplies + SET journalist_id= + ( + SELECT users.id + FROM users + WHERE journalist_id=users.uuid + ) + WHERE exists + ( + SELECT users.id + FROM users + WHERE journalist_id=users.uuid + ); + """) + def downgrade(): """ diff --git a/tests/migrations/test_a4bf1f58ce69.py b/tests/migrations/test_a4bf1f58ce69.py index e40ede0e5c..28af0d562f 100644 --- a/tests/migrations/test_a4bf1f58ce69.py +++ b/tests/migrations/test_a4bf1f58ce69.py @@ -5,9 +5,9 @@ import subprocess from securedrop_client import db -from securedrop_client.db import Reply, User +from securedrop_client.db import DraftReply, Reply, User -from .utils import add_reply, add_source, add_user +from .utils import add_reply, add_draft_reply, add_source, add_user random.seed("=^..^=..^=..^=") @@ -28,7 +28,8 @@ def __init__(self, homedir): def load_data(self): """ - Load data that has the bug where user.uuid is stored in replies.journalist_id. + Load data that has the bug where user.uuid is stored in replies.journalist_id and + draftreplies.journalist_id. """ for _ in range(self.NUM_SOURCES): add_source(self.session) @@ -57,12 +58,20 @@ def load_data(self): source_id = random.randint(1, self.NUM_SOURCES) add_reply(self.session, journalist.uuid, source_id) + # Add draft replies from randomly-selected journalists to a randomly-selected sources + for _ in range(1, self.NUM_REPLIES): + journalist_id = random.randint(1, self.NUM_USERS) + journalist = self.session.query(User).filter_by(id=journalist_id).one() + source_id = random.randint(1, self.NUM_SOURCES) + add_draft_reply(self.session, journalist.uuid, source_id) + self.session.commit() def check_upgrade(self): """ - Make sure each reply in the replies table has the correct journalist_id stored for the - associated journalist by making sure a User account exists with that journalist id. + Make sure each reply in the replies and draftreplies tables have the correct journalist_id + stored for the associated journalist by making sure a User account exists with that + journalist id. """ replies = self.session.query(Reply).all() assert len(replies) @@ -71,6 +80,13 @@ def check_upgrade(self): # Will fail if User does not exist self.session.query(User).filter_by(id=reply.journalist_id).one() + draftreplies = self.session.query(DraftReply).all() + assert len(draftreplies) + + for draftreply in draftreplies: + # Will fail if User does not exist + self.session.query(User).filter_by(id=draftreply.journalist_id).one() + self.session.close() diff --git a/tests/migrations/utils.py b/tests/migrations/utils.py index 3cabf86fe6..650918a362 100644 --- a/tests/migrations/utils.py +++ b/tests/migrations/utils.py @@ -8,7 +8,7 @@ from sqlalchemy import text from sqlalchemy.orm.session import Session -from securedrop_client.db import DownloadError, Source +from securedrop_client.db import DownloadError, ReplySendStatus, Source random.seed("ᕕ( ᐛ )ᕗ") @@ -152,3 +152,48 @@ def add_reply(session: Session, journalist_id: int, source_id: int) -> None: ) """ session.execute(text(sql), params) + + +def add_draft_reply(session: Session, journalist_id: int, source_id: int) -> None: + reply_send_statuses = session.query(ReplySendStatus).all() + reply_send_status_ids = [reply_send_status.id for reply_send_status in reply_send_statuses] + + content = random_chars(1000) + + source = session.query(Source).filter_by(id=source_id).one() + + file_counter = len(source.collection) + 1 + + params = { + "uuid": str(uuid4()), + "journalist_id": journalist_id, + "source_id": source_id, + "file_counter": file_counter, + "content": content, + "send_status_id": random.choice(reply_send_status_ids), + "timestamp": random_datetime(), + } + + sql = """ + INSERT INTO draftreplies + ( + uuid, + journalist_id, + source_id, + file_counter, + content, + send_status_id, + timestamp + ) + VALUES + ( + :uuid, + :journalist_id, + :source_id, + :file_counter, + :content, + :send_status_id, + :timestamp + ) + """ + session.execute(text(sql), params)