Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from common.plugins import jwt, mail
from database.database import db
from routes.auth import auth
from routes.puzzle import puzzle
from routes.leaderboard import leaderboard
from routes.puzzle import puzzle
from routes.user import user


Expand All @@ -28,7 +28,7 @@ def handle_exception(error):
return response


def create_app():
def create_app(config={}):
app = Flask(__name__)
CORS(app)

Expand All @@ -53,6 +53,9 @@ def create_app():
app.config["MAIL_USE_TLS"] = True
app.config["MAIL_USE_SSL"] = False

for key, value in config.items():
app.config[key] = value

app.after_request(update_token)

# Initialise plugins
Expand Down
24 changes: 24 additions & 0 deletions backend/common/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,28 @@

## EMAIL VERIFICATION

def add_verification(data, code):
# We use a pipeline here to ensure these instructions are atomic
pipeline = cache.pipeline()

pipeline.hset(f"register:{code}", mapping=data)
pipeline.expire(f"register:{code}", timedelta(hours=1))

pipeline.execute()

def get_verification(code):
key = f"register:{code}"

if not cache.exists(key):
return None

result = {}

for key, value in cache.hgetall(key).items():
result[key.decode()] = value.decode()

return result

## LOCKOUT

def register_incorrect(id):
Expand Down Expand Up @@ -46,5 +68,7 @@ def is_blocked(id):
token = cache.get(f"block_{id}")
return token is not None

## GENERAL FUNCTIONS

def clear_redis():
cache.flushdb()
26 changes: 7 additions & 19 deletions backend/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from itsdangerous import URLSafeTimedSerializer

from common.exceptions import AuthError, InvalidError, RequestError
from common.redis import cache
from common.redis import add_verification, get_verification
from database.user import add_user, email_exists, fetch_user, get_user_info, username_exists

hasher = PasswordHasher(
Expand Down Expand Up @@ -64,31 +64,19 @@ def register(email, username, password):
"password": hashed
}

# We use a pipeline here to ensure these instructions are atomic
pipeline = cache.pipeline()

pipeline.hset(f"register:{code}", mapping=data)
pipeline.expire(f"register:{code}", timedelta(hours=1))

pipeline.execute()
add_verification(data, code)

return code

@staticmethod
def register_verify(token):
cache_key = f"register:{token}"
def register_verify(code):
result = get_verification(code)

if not cache.exists(cache_key):
if result is None:
raise AuthError("Token expired or does not correspond to registering user")

result = cache.hgetall(cache_key)
stringified = {}

for key, value in result.items():
stringified[key.decode()] = value.decode()

id = add_user(stringified["email"], stringified["username"], stringified["password"])
return User(id, stringified["email"], stringified["username"], stringified["password"])
id = add_user(result["email"], result["username"], result["password"])
return User(id, result["email"], result["username"], result["password"])

@staticmethod
def login(email, password):
Expand Down
20 changes: 9 additions & 11 deletions backend/test/auth/register_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import email
import os
import poplib
import requests

# Imports for pytest
from pytest_mock import mocker

from test.helpers import clear_all, db_add_user
from test.fixtures import app, client
from test.mock.mock_mail import mailbox


def register(json):
Expand Down Expand Up @@ -64,15 +65,13 @@ def test_duplicate_username(client):
assert response.status_code == 400


def test_success(client):
def test_register_success(client, mocker):
mocker.patch("routes.auth.mail", mailbox)

clear_all()

# Check that we get an email sent
mailbox = poplib.POP3("pop3.mailtrap.io", 1100)
mailbox.user(os.environ["MAILTRAP_USERNAME"])
mailbox.pass_(os.environ["MAILTRAP_PASSWORD"])

(before, _) = mailbox.stat()
before = len(mailbox.messages)

# Register normally
response = client.post("/auth/register", json={
Expand All @@ -84,12 +83,11 @@ def test_success(client):
assert response.status_code == 200

# Check that an email was in fact sent
(after, _) = mailbox.stat()
after = len(mailbox.messages)

assert after == before + 1

# Verify recipient
raw_email = b"\n".join(mailbox.retr(1)[1])
parsed_email = email.message_from_bytes(raw_email)
parsed_email = mailbox.get_message(-1)

assert parsed_email["To"] == "asdfghjkl@gmail.com"
15 changes: 7 additions & 8 deletions backend/test/auth/register_verify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import common
from test.helpers import clear_all
from test.fixtures import app, client
from test.mock.mock_mail import mailbox

## HELPER FUNCTIONS

Expand All @@ -38,6 +39,8 @@ def test_invalid_token(client):
# TODO: try working on this, if not feasible delete this test and test manually
@pytest.mark.skip()
def test_token_expired(client, mocker):
clear_all()

fake = fakeredis.FakeStrictRedis()
mocker.patch.object(common.redis, "cache", return_value=fake)

Expand Down Expand Up @@ -79,7 +82,9 @@ def test_token_expired(client, mocker):

assert response.status_code == 401

def test_success(client):
def test_verify_success(client, mocker):
mocker.patch("routes.auth.mail", mailbox)

clear_all()

register_response = client.post("/auth/register", json={
Expand All @@ -91,13 +96,7 @@ def test_success(client):
assert register_response.status_code == 200

# Check inbox
mailbox = poplib.POP3("pop3.mailtrap.io", 1100)
mailbox.user(os.environ["MAILTRAP_USERNAME"])
mailbox.pass_(os.environ["MAILTRAP_PASSWORD"])

# Check the contents of the email, and harvest the token from there
raw_email = b"\n".join(mailbox.retr(1)[1])
parsed_email = email.message_from_bytes(raw_email)
parsed_email = mailbox.get_message(-1)

# Assuming there's a HTML part
for part in parsed_email.walk():
Expand Down
13 changes: 9 additions & 4 deletions backend/test/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from app import create_app

import pytest
from pytest_mock import mocker

from app import create_app
from test.mock.mock_mail import mailbox

@pytest.fixture()
def app():
app = create_app()
app.config["TESTING"] = True
def app(mocker):
# Mock only where the data is being used
mocker.patch("app.mail", mailbox)
mocker.patch("common.plugins.mail", mailbox)

app = create_app({"TESTING": True})
yield app

@pytest.fixture()
Expand Down
8 changes: 8 additions & 0 deletions backend/test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from database.puzzle import add_part, add_question, add_competition
from models.user import User

## DATABASE FUNCTIONS

def db_add_competition(name):
return add_competition(name)
Expand All @@ -24,6 +25,8 @@ def clear_all():
# Clear database
clear_database()

## HEADER FUNCTIONS

def get_cookie_from_header(response, cookie_name):
cookie_headers = response.headers.getlist("Set-Cookie")

Expand All @@ -44,3 +47,8 @@ def get_cookie_from_header(response, cookie_name):
def generate_csrf_header(response):
csrf_token = get_cookie_from_header(response, "csrf_access_token")["csrf_access_token"]
return {"X-CSRF-TOKEN": csrf_token}

## EMAIL MOCKING

def get_emails():
pass
19 changes: 19 additions & 0 deletions backend/test/mock/mock_mail.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import email
import flask_mail

class MockMail:
def __init__(self):
self.ascii_attachments = False
self.messages = []

def init_app(self, app):
app.extensions = getattr(app, 'extensions', {})
app.extensions['mail'] = self

def send(self, message: flask_mail.Message):
self.messages.append(message.as_bytes())

def get_message(self, n):
return email.message_from_bytes(self.messages[n])

mailbox = MockMail()