Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Allie Crevier committed Nov 4, 2020
1 parent 15ba68a commit 397998b
Show file tree
Hide file tree
Showing 9 changed files with 850 additions and 108 deletions.
21 changes: 21 additions & 0 deletions securedrop_client/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ def seen(self) -> bool:

return False

def seen_by(self, journalist_id: int) -> bool:
for seen_message in self.seen_messages:
if seen_message.journalist_id == journalist_id:
return True

return False


class File(Base):

Expand Down Expand Up @@ -294,6 +301,13 @@ def seen(self) -> bool:

return False

def seen_by(self, journalist_id: int) -> bool:
for seen_file in self.seen_files:
if seen_file.journalist_id == journalist_id:
return True

return False


class Reply(Base):

Expand Down Expand Up @@ -387,6 +401,13 @@ def seen(self) -> bool:
"""
return True

def seen_by(self, journalist_id: int) -> bool:
for seen_reply in self.seen_replies:
if seen_reply.journalist_id == journalist_id:
return True

return False


class DownloadErrorCodes(Enum):
"""
Expand Down
15 changes: 9 additions & 6 deletions securedrop_client/gui/widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,25 +1177,28 @@ def _on_authentication_changed(self, authenticated: bool) -> None:

@pyqtSlot(str)
def _on_mark_seen(self, source_uuid: str):
"""
Immediately show the source widget as having been seen and tell the controller to make a
seen API request to mark all files, messages, and replies as unseen by the current user as
seen.
"""
if self.source_uuid != source_uuid:
return

# immediately update styles to mark as seen
self.seen = True
self.update_styles()

# Prepare the lists of uuids to mark as seen by the current user. Continue to process the
# next item if the source conversation item has already been seen by the current user or if
# it no longer exists.
try:
if self.source.seen:
return

# Prepare the lists of uuids to mark as seen. Continue if one of the source conversation
# items no longer exists so that the rest of the items will be marked as seen.
files = []
messages = []
replies = []
source_items = self.source.collection
for item in source_items:
if item.seen:
if item.seen_by(self.controller.authenticated_user.id):
continue

try:
Expand Down
130 changes: 79 additions & 51 deletions securedrop_client/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,38 +257,10 @@ def __update_submissions(
lazy_setattr(local_submission, "is_read", submission.is_read)
lazy_setattr(local_submission, "download_url", submission.download_url)

# Add seen record if one doesn't yet exist
for journalist_uuid in submission.seen_by:
journalist = session.query(User).filter_by(uuid=journalist_uuid).one_or_none()

# Do not add seen record if journalist is missing from the local db. If the
# journalist accont was deleted, wait until the server says so.
if not journalist:
return

journalist_id = journalist.id
if model == File:
seen_file = (
session.query(SeenFile)
.filter_by(file_id=local_submission.id, journalist_id=journalist_id)
.one_or_none()
)
if not seen_file:
seen_file = SeenFile(
file_id=local_submission.id, journalist_id=journalist_id
)
session.add(seen_file)
elif model == Message:
seen_message = (
session.query(SeenMessage)
.filter_by(message_id=local_submission.id, journalist_id=journalist_id)
.one_or_none()
)
if not seen_message:
seen_message = SeenMessage(
message_id=local_submission.id, journalist_id=journalist_id
)
session.add(seen_message)
if model == File:
add_seen_file_records(local_submission.id, submission.seen_by, session)
elif model == Message:
add_seen_message_records(local_submission.id, submission.seen_by, session)

# Removing the UUID from local_uuids ensures this record won't be
# deleted at the end of this function.
Expand All @@ -306,6 +278,11 @@ def __update_submissions(
download_url=submission.download_url,
)
session.add(ns)
session.flush()
if model == File:
add_seen_file_records(ns.id, submission.seen_by, session)
elif model == Message:
add_seen_message_records(ns.id, submission.seen_by, session)
logger.debug(f"Added {model.__name__} {submission.uuid}")

# The uuids remaining in local_uuids do not exist on the remote server, so
Expand All @@ -318,6 +295,72 @@ def __update_submissions(
session.commit()


def add_seen_file_records(file_id: int, journalist_uuids: List[str], session: Session) -> None:
"""
Add a seen record for each journalist that saw the file.
"""
for journalist_uuid in journalist_uuids:
journalist = session.query(User).filter_by(uuid=journalist_uuid).one_or_none()

# Do not add seen record if journalist is missing from the local db. If the
# journalist account needs to be created or deleted, wait until the server says so.
if not journalist:
return

seen_file = (
session.query(SeenFile)
.filter_by(file_id=file_id, journalist_id=journalist.id)
.one_or_none()
)
if not seen_file:
seen_file = SeenFile(file_id=file_id, journalist_id=journalist.id)
session.add(seen_file)


def add_seen_message_records(msg_id: int, journalist_uuids: List[str], session: Session) -> None:
"""
Add a seen record for each journalist that saw the message.
"""
for journalist_uuid in journalist_uuids:
journalist = session.query(User).filter_by(uuid=journalist_uuid).one_or_none()

# Do not add seen record if journalist is missing from the local db. If the
# journalist account needs to be created or deleted, wait until the server says so.
if not journalist:
return

seen_message = (
session.query(SeenMessage)
.filter_by(message_id=msg_id, journalist_id=journalist.id)
.one_or_none()
)
if not seen_message:
seen_message = SeenMessage(message_id=msg_id, journalist_id=journalist.id)
session.add(seen_message)


def add_seen_reply_records(reply_id: int, journalist_uuids: List[str], session: Session) -> None:
"""
Add a seen record for each journalist that saw the reply.
"""
for journalist_uuid in journalist_uuids:
journalist = session.query(User).filter_by(uuid=journalist_uuid).one_or_none()

# Do not add seen record if journalist is missing from the local db. If the
# journalist account needs to be created or deleted, wait until the server says so.
if not journalist:
return

seen_reply = (
session.query(SeenReply)
.filter_by(reply_id=reply_id, journalist_id=journalist.id)
.one_or_none()
)
if not seen_reply:
seen_reply = SeenReply(reply_id=reply_id, journalist_id=journalist.id)
session.add(seen_reply)


def update_replies(
remote_replies: List[SDKReply], local_replies: List[Reply], session: Session, data_dir: str
) -> None:
Expand Down Expand Up @@ -365,25 +408,7 @@ def update_replies(
lazy_setattr(local_reply, "size", reply.size)
lazy_setattr(local_reply, "filename", reply.filename)

# Add seen record if one doesn't yet exist
for journalist_uuid in reply.seen_by:
journalist = session.query(User).filter_by(uuid=journalist_uuid).one_or_none()

# Do not add seen record if journalist is missing from the local db. If the
# journalist accont was deleted, wait until the server says so.
if not journalist:
return

journalist_id = journalist.id

seen_reply = (
session.query(SeenReply)
.filter_by(reply_id=local_reply.id, journalist_id=journalist_id)
.one_or_none()
)
if not seen_reply:
seen_reply = SeenReply(reply_id=local_reply.id, journalist_id=journalist_id)
session.add(seen_reply)
add_seen_reply_records(local_reply.id, reply.seen_by, session)

del local_replies_by_uuid[reply.uuid]
logger.debug("Updated reply {}".format(reply.uuid))
Expand All @@ -402,6 +427,9 @@ def update_replies(
size=reply.size,
)
session.add(nr)
session.flush()

add_seen_reply_records(nr.id, reply.seen_by, session)

# All replies fetched from the server have succeeded in being sent,
# so we should delete the corresponding draft locally if it exists.
Expand Down
22 changes: 22 additions & 0 deletions tests/api_jobs/test_seen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from securedrop_client.api_jobs.seen import SeenJob
from tests import factory


def test_seen(homedir, mocker, session, source):
"""
Check if we call add_star method if a source is not stared.
"""
file = factory.File(id=1, source=source["source"])
message = factory.Message(id=2, source=source["source"])
reply = factory.Reply(source=factory.Source())
session.add(file)
session.add(message)
session.add(reply)

api_client = mocker.MagicMock()

job = SeenJob([file], [message], [reply])

job.call_api(api_client, session)

api_client.seen.assert_called_once_with([file], [message], [reply])
54 changes: 52 additions & 2 deletions tests/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import List

from sdclientapi import Reply as SDKReply
from sdclientapi import Submission as SDKSubmission
from sdclientapi import Source as SDKSource

from securedrop_client import db
Expand Down Expand Up @@ -168,7 +169,6 @@ def call_api(self, api_client, session):


def RemoteSource(**attrs):

with open(os.path.join(os.path.dirname(__file__), "files", "test-key.gpg.pub.asc")) as f:
pub_key = f.read()

Expand Down Expand Up @@ -196,20 +196,70 @@ def RemoteSource(**attrs):


def RemoteReply(**attrs):

source_url = "/api/v1/sources/{}".format(str(uuid.uuid4()))
defaults = dict(
filename="1-reply.filename",
journalist_uuid=str(uuid.uuid4()),
journalist_username="test",
journalist_first_name="",
journalist_last_name="",
file_counter=1,
is_deleted_by_source=False,
reply_url="test",
size=1234,
uuid=str(uuid.uuid4()),
source_url=source_url,
seen_by=[],
)

defaults.update(attrs)

return SDKReply(**defaults)


def RemoteFile(**attrs):
source_url = "/api/v1/sources/{}".format(str(uuid.uuid4()))
defaults = dict(
source_uuid="user-uuid-1",
download_url="test",
submission_url="test",
filename="1-submission.filename",
is_read=False,
file_counter=1,
is_deleted_by_source=False,
reply_url="test",
size=1234,
is_decrypted=True,
is_downloaded=True,
uuid=str(uuid.uuid4()),
source_url=source_url,
seen_by=[],
)

defaults.update(attrs)

return SDKSubmission(**defaults)


def RemoteMessage(**attrs):
source_url = "/api/v1/sources/{}".format(str(uuid.uuid4()))
defaults = dict(
source_uuid="user-uuid-1",
download_url="test",
submission_url="test",
filename="1-submission.filename",
is_read=False,
file_counter=1,
is_deleted_by_source=False,
reply_url="test",
size=1234,
is_decrypted=True,
is_downloaded=True,
uuid=str(uuid.uuid4()),
source_url=source_url,
seen_by=[],
)

defaults.update(attrs)

return SDKSubmission(**defaults)
Loading

0 comments on commit 397998b

Please sign in to comment.