Skip to content

Commit

Permalink
Added local sent messages file (per-chat) to allow message editing
Browse files Browse the repository at this point in the history
  • Loading branch information
etrian-dev committed Mar 13, 2022
1 parent df13441 commit 6848ed7
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 58 deletions.
1 change: 1 addition & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
- Encryption/decryption of message data with RSA
- Local message storage (sent messages) in plaintext on file or db
- logout function and proper authentication (maybe sessions)
- Fix msgstore to write into ./instance
# Container
- Add volume to make source changes shared between host and container
23 changes: 13 additions & 10 deletions chatroom/Auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

@blueprint.route('/register', methods=['GET', 'POST'])
def register():
'''Register a new user.
'''
error = None
if request.method == 'POST':
# create a new user
Expand All @@ -19,9 +21,12 @@ def register():
# insert the newly created user into the db
conn = db.get_db()
try:
pk_n_bytes = new_user.pub_key[0].to_bytes(new_user.pub_key[0].bit_length() // 8 + 1,byteorder='big')
pk_e_bytes = new_user.pub_key[1].to_bytes(new_user.pub_key[1].bit_length() // 8 + 1,byteorder='big')
pk_d_bytes = new_user.priv_key.to_bytes(new_user.priv_key.bit_length() // 8 + 1,byteorder='big')
cur = conn.execute('''
INSERT INTO Users(username, password, pk_n, pk_e, pk_d)
VALUES (?,?,?,?,?)''', [new_user.username, new_user.password, new_user.pub_key[0], new_user.pub_key[1], new_user.priv_key])
VALUES (?,?,?,?,?)''', [new_user.username, new_user.password, pk_n_bytes, pk_e_bytes, pk_d_bytes])
# gets the rowid of the last row, which is the same as the new user_id
# iff user_id is an integer primary key with autoincrement
new_user.user_id = cur.lastrowid
Expand All @@ -38,6 +43,8 @@ def register():
@blueprint.route('/', methods=['GET'])
@blueprint.route('/login', methods=['GET', 'POST'])
def login():
'''Login into user profile.
'''
error = None
if request.method == 'POST':
username = request.form['username']
Expand All @@ -47,26 +54,22 @@ def login():
error = 'Username or password unspecified'
else:
conn = db.get_db()
cur = conn.cursor()
cur.execute('SELECT user_id, username, password FROM Users WHERE username=?;', [username])
cur = conn.execute('SELECT user_id, username, password FROM Users WHERE username=?;', [username])

user_id = None
row = cur.fetchone()
while row is not None:
if row['password'] == pwd:
user_id = row['user_id']
# create a new session for this user
cur.execute('INSERT INTO Sessions(userref,login_tm) VALUES (?, CURRENT_TIMESTAMP)', [user_id])
conn.commit()
break;
# TODO: implement session creation here
break
row = cur.fetchone()
cur.close()
# User found: redirect to its own chats page
if user_id is not None:
return redirect(url_for('chat.home_user', user_id=user_id))
error = 'Password incorrect'
return render_template('login.html', error=error)
error = 'Password incorrect or user inexistent'
if error is not None:
print(error)
# store error to be shown
flash(error)
return render_template('login.html', error=error)
Expand Down
25 changes: 20 additions & 5 deletions chatroom/Chat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from . import db
from . import Msg

from time import time
from datetime import datetime
from sqlite3 import Connection, Cursor, DatabaseError
from json import load

from flask import (
Blueprint, flash, g, redirect, render_template, request, session, url_for
Expand Down Expand Up @@ -120,18 +122,18 @@ def display_chat(user, other):
WHERE (participant1 = ? AND ? = participant2) OR (participant1 = ? AND participant2 = ?);
''', [user, other, other, user])
chat = cur.fetchone()
# fetch all messages
# fetch all messages sent by the other user
cur.execute('''
SELECT * FROM Messages WHERE chatref = ?;
''', [chat['chat_id']])
SELECT * FROM Messages WHERE chatref = ? AND sender == ? ;
''', [chat['chat_id'], other])
msgs_encoded = cur.fetchall()
messages = []
# build breadcrumb
breadcrumb = dict()
breadcrumb['home'] = url_for('chat.home_user', user_id=user)
breadcrumb[chat_info['other_user']] = url_for('chat.display_chat', user=user, other=other)
chat_info['breadcrumb'] = breadcrumb
#decode messages
messages = []
for msg in msgs_encoded:
sender = None
receiver = None
Expand All @@ -145,7 +147,20 @@ def display_chat(user, other):
{'msg_id': msg['msg_id'],
'sender': sender,
'receiver': receiver,
'data': msg['msg_data'].decode(encoding='utf-8')})
'data': Msg.decrypt_message(user, msg['msg_data'])})
# fetch all messages sent by this user
try:
with open(Msg.get_msgstore(user, other), 'r') as msgstore:
sent_messages = load(msgstore)
for msg in sent_messages:
messages.append(
{'msg_id': msg['msg_id'],
'sender': chat_info['this_user'],
'receiver': chat_info['other_user'],
'data': msg['data']})
print(sent_messages)
except FileNotFoundError:
pass # no messages sent by this user yet
chat_info['messages'] = messages

return render_template('messages.html', **chat_info)
Expand Down
121 changes: 92 additions & 29 deletions chatroom/Msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,16 @@ def __init__(self, data: str, stamp: int):
from . import db
from sqlite3 import Connection, Cursor, DatabaseError
from json import (load, loads, dump, dumps)
from os import fspath, path
from os import fspath, path, SEEK_END

from flask import (
Blueprint, flash, g, redirect, render_template, request, session, url_for, make_response
Blueprint, flash, g, current_app, redirect, render_template, request, session, url_for, make_response
)
from flask import Flask
from flask.json import jsonify

blueprint = Blueprint('msg', __name__, url_prefix='/msg')

def get_msgstore(sender: int, recipient: int):
with open(fspath(f'{sender}_{recipient}_sent.json'), 'r') as msgstore:
data = load(msgstore)
# Create the message file
open('sent.json', 'x')

def get_all_messages(sender, receiver):
return f"GET all msg from {sender} to {receiver}"

Expand All @@ -38,22 +33,44 @@ def msg_to_int(msg: str) -> int:
def int_to_msg(num: int) -> str:
"""Simple bijective decoding function from integers to strings.
"""
return str(num.to_bytes(num.bit_length(), byteorder='big'), 'utf-8')
return str(num.to_bytes(num.bit_length() // 8 + 1, byteorder='big'), 'utf-8')

def encrypt_msg(recipient: int, plaintext: str) -> int:
'''Encrypt the plaintext message with recipient's public key.
'''
# get the public key of the recipient
cur = db.get_db().execute('''
SELECT pk_e,pk_n FROM Users WHERE user_id = ?
SELECT pk_e,pk_n,pk_d FROM Users WHERE user_id = ?
''', [recipient])
res = cur.fetchone()
pubkey_n = res['pk_n']
pubkey_e = res['pk_e']
pubkey_n = int.from_bytes(res['pk_n'], byteorder='big')
pubkey_e = int.from_bytes(res['pk_e'], byteorder='big')
# Encrypt this message with the recipient's public key
encrypted_msg = rsa_encrypt(msg_to_int(msg_data))
return enctypted_msg
orig_msg_int = msg_to_int(plaintext)
print('original:',orig_msg_int)
encrypted_msg = rsa.rsa_encrypt(orig_msg_int, pubkey_e, pubkey_n)

pubkey_d = int.from_bytes(res['pk_d'], byteorder='big')
decrypted_msg = rsa.rsa_decrypt(encrypted_msg, pubkey_d, pubkey_n)
assert orig_msg_int == decrypted_msg, "Messages differ"

return encrypted_msg

def decrypt_message(recipient: int, data: bytes) -> str:
'''Decrypts a bytes object directed to this user into a message'''
cur = db.get_db().execute('''
SELECT pk_n,pk_d FROM Users WHERE user_id = ?
''', [recipient])
res = cur.fetchone()
pubkey_n = int.from_bytes(res['pk_n'], byteorder='big')
priv_key = int.from_bytes(res['pk_d'], byteorder='big')
decrypted_msg = rsa.rsa_decrypt(int.from_bytes(data, byteorder='big'), priv_key, pubkey_n)
print('decrypted:', decrypted_msg)
return int_to_msg(decrypted_msg)

def get_msgstore(sender, receiver) -> str:
# TODO: fix
return fspath(f'{current_app.instance_path}/{sender}_{receiver}_sent.json')

@blueprint.route('/<int:sender>/<int:recipient>/', methods=['POST'])
def send_message(sender, recipient):
Expand All @@ -69,15 +86,33 @@ def send_message(sender, recipient):
chat_id = cur.fetchone()['chat_id']
# retrieve the message data
msg_data = request.form['message']

# encrypt
encrypted_msg = encrypt_msg(recipient, msg_data)
# insert the message
cur.execute('''
INSERT INTO Messages(chatref,sender,recipient,msg_data)
VALUES (?, ?, ?, ?);
''', [chat_id, sender, recipient, enctypted_msg.to_bytes()])
''', [chat_id, sender, recipient, encrypted_msg.to_bytes(length=encrypted_msg.bit_length() // 8 + 1, byteorder='big')])
# store the last row id == msg_id in this case
msg_id = cur.lastrowid
db_conn.commit()
print(f"{sender} said: {msg_data} to {recipient}")
# fetch the message id
# Save the message to the msgstore
msg_obj = {"msg_id": msg_id, "sender": sender, "recipient": recipient, "data": msg_data}
file_exists = path.exists(fspath(f'{sender}_{recipient}_sent.json'))
if file_exists:
with open(get_msgstore(sender, recipient), 'r+b') as msgstore:
# cancel array end character
b = msgstore.seek(-1, SEEK_END)
print('seek = ', b)
msgstore.write((',\n' + dumps(msg_obj) + ' ]').encode(encoding='utf-8'))
else:
with open(get_msgstore(sender, recipient), 'a') as msgstore:
msgstore.write('[ ' + dumps(msg_obj) + ' ]')
print('last pos = ', msgstore.tell())

print(f"{sender} said to {recipient}: {msg_data}")
# return an empty response, with a 201 Created code
return redirect(url_for('chat.display_chat', user=sender, other=recipient))
except DatabaseError:
Expand All @@ -86,31 +121,48 @@ def send_message(sender, recipient):

@blueprint.route('/<int:msg_id>', methods=['GET', 'PUT'])
def edit_message(msg_id):
# Get the message sender and receiver
cur = db.get_db().execute('''
SELECT sender,recipient FROM Messages WHERE msg_id = ?;
''', [msg_id])
sender_recipient = cur.fetchone()
# Modify the message
if request.method == 'PUT':
# Get the message sender and receiver
cur = db.get_db().execute('''
SELECT sender,recipient FROM Messages WHERE msg_id = ?;
''', [msg_id])
sender_recipient = cur.fetchone()
# get the new message (plaintext)
newmsg = request.get_json()['msg']
newmsg_data = request.get_json()['msg']
# encrypt the new message and update the db
encrypted_msg = encrypt_msg(sender_recipient['recipient'], newmsg)
encrypted_msg = encrypt_msg(sender_recipient['recipient'], newmsg_data)
cur.execute('''
UPDATE Messages SET msg_data = ? WHERE msg_id = ?;
''', [encrypted_msg.to_bytes(), msg_id])
''', [encrypted_msg.to_bytes(encrypted_msg.bit_length() // 8 + 1, byteorder='big'), msg_id])
db.get_db().commit()
# update the message stored in the msgstore as well
old = None
newmsg = None
with open(get_msgstore(sender_recipient['sender'], sender_recipient['recipient']), 'r') as msgstore:
messages = load(msgstore)
for msg in messages:
if msg['msg_id'] == msg_id:
newmsg = msg
newmsg['data'] = newmsg_data
old = msg
break
messages.remove(old)
messages.append(newmsg)
with open(get_msgstore(sender_recipient['sender'], sender_recipient['recipient']), 'w') as msgstore:
msgstore.write(dumps(messages))
flash(f"Message {msg_id} updated successfully")
url = url_for('chat.display_chat', user=sender_recipient['sender'], other=sender_recipient['recipient'])
return jsonify({"url": url})
# Otherwise render the webpage containing the message to be modified
msg_data = dict()
msg_data['msg_id'] = msg_id
db_conn = db.get_db()
cur = db_conn.execute('''
SELECT msg_data FROM Messages WHERE msg_id = ?;''', [msg_id])
msg_row = cur.fetchone()
msg_data['old_msg'] = msg_row['msg_data'].decode(encoding='utf-8')
with open(get_msgstore(sender_recipient['sender'], sender_recipient['recipient']), 'r') as msgstore:
messages = load(msgstore)
for msg in messages:
if msg['msg_id'] == msg_id:
msg_data['old_msg'] = msg['data']
break
return render_template('edit_message.html', **msg_data)


Expand All @@ -127,5 +179,16 @@ def delete_message(msg_id):
cur.execute('''
DELETE FROM Messages WHERE msg_id = ?;
''', [msg_id])
# Delete from the msgstore as well
to_delete = None
with open(get_msgstore(sender_recipient['sender'], sender_recipient['recipient']), 'r') as msgstore:
messages = load(msgstore)
for msg in messages:
if msg['msg_id'] == msg_id:
to_delete = msg
break
messages.remove(to_delete)
with open(get_msgstore(sender_recipient['sender'], sender_recipient['recipient']), 'w') as msgstore:
msgstore.write(dumps(messages))
db.get_db().commit()
return jsonify({"url": url})
6 changes: 3 additions & 3 deletions chatroom/User.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from secrets import token_bytes
from hashlib import sha3_256

PRIME_LEN = 15
ITERATIONS = 10
SALT_LEN = 16
PRIME_LEN = 512
ITERATIONS = 15
SALT_LEN = 32


class User:
Expand Down
4 changes: 2 additions & 2 deletions chatroom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def create_app(testing=None):
app.config.from_mapping(
SECRET_KEY='dev',
DATABASE=os.path.join(app.instance_path, 'chatroom.sqlite'),
#EXPLAIN_TEMPLATE_LOADING=True,
)
print(app.config['DATABASE'])
# ensure the instance folder exists
# ensure the instance folder exists
try:
os.makedirs(app.instance_path)
except OSError:
Expand Down
6 changes: 3 additions & 3 deletions chatroom/crypto/rsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def rsa_encrypt(message: int, exponent: int, mod: int) -> int:
that a message greater than modulus is not allowed
"""
if message > mod:
return -1
raise Exception(f"{message} > {mod}!")
else:
return fastext.fastexp(message, exponent, mod)
return fastexp.fastexp(message, exponent, mod)

def rsa_decrypt(ciphertext: int, private_key: int, mod: int) -> int:
"""RSA decryption function.
Expand All @@ -45,7 +45,7 @@ def rsa_decrypt(ciphertext: int, private_key: int, mod: int) -> int:
"""
if ciphertext > mod:
# TODO: check this name
raise IllegalArgument
raise Exception(f"{msg} > {mod}!")
else:
return fastexp.fastexp(ciphertext, private_key, mod)

Expand Down
5 changes: 3 additions & 2 deletions chatroom/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ def close_db(e=None):

def init_db():
gdb = get_db().cursor()
with current_app.open_resource('db_schema.sql') as f:
gdb.executescript(f.read().decode('utf8'))
with current_app.open_resource('db_schema.sql', 'r') as f:
contents = f.read()
gdb.executescript(contents)
# TODO: untested
def querydb(query: str, args: list):
'''Generator that queries the database and returns the resulting rows.
Expand Down
6 changes: 3 additions & 3 deletions chatroom/db_schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ CREATE TABLE Users (
user_id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL,
password TEXT NOT NULL,
pk_n INTEGER NOT NULL,
pk_e INTEGER NOT NULL,
pk_d INTEGER NOT NULL
pk_n BLOB NOT NULL,
pk_e BLOB NOT NULL,
pk_d BLOB NOT NULL
);

CREATE TABLE Sessions (
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
flask == 2.0.3
flask == 3.0.3
python-dotenv == 0.19.2

0 comments on commit 6848ed7

Please sign in to comment.