Skip to content

Commit a8d79e2

Browse files
authored
Merge pull request #4 from thephilomaths/add-tests
Bug fix and added user model test
2 parents 9a64c52 + 254a653 commit a8d79e2

File tree

14 files changed

+495
-23
lines changed

14 files changed

+495
-23
lines changed

.isort.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
line_length = 88
33
multi_line_output = 3
44
include_trailing_comma = True
5-
known_third_party = cryptography,flask,sqlalchemy
5+
known_third_party = cryptography,flask,pytest,sqlalchemy

app.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,7 @@ def shutdown_session() -> None:
1313
"""
1414

1515
db_session.remove()
16+
17+
18+
if __name__ == "__main__":
19+
app.run(host="0.0.0.0", port=5000, debug=True, threaded=True)

requirements.txt

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,30 @@ SQLAlchemy==1.3.19
1616
toml==0.10.1
1717
virtualenv==20.0.31
1818
Werkzeug==1.0.1
19+
appdirs==1.4.4
20+
attrs==20.1.0
21+
cfgv==3.2.0
22+
click==7.1.2
23+
distlib==0.3.1
24+
filelock==3.0.12
25+
Flask==1.1.2
26+
identify==1.4.29
27+
iniconfig==1.0.1
28+
itsdangerous==1.1.0
29+
Jinja2==2.11.2
30+
MarkupSafe==1.1.1
31+
more-itertools==8.5.0
32+
nodeenv==1.5.0
33+
packaging==20.4
34+
pluggy==0.13.1
35+
pre-commit==2.7.1
36+
psycopg2==2.8.5
37+
py==1.9.0
38+
pyparsing==2.4.7
39+
pytest==6.0.1
40+
PyYAML==5.3.1
41+
six==1.15.0
42+
SQLAlchemy==1.3.19
43+
toml==0.10.1
44+
virtualenv==20.0.31
45+
Werkzeug==1.0.1

ssh_manager_backend/app/models/access_control.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,26 @@ class AccessControlModel:
1010
def __init__(self):
1111
self.session = db_session()
1212

13+
def create(self, username: str):
14+
"""
15+
Creates an entry in the access_control table for the given user.
16+
17+
:param username
18+
:return: boolean value whether the entry is created or not.
19+
"""
20+
21+
try:
22+
acl_details: AccessControl = AccessControl(
23+
username=username, ip_addresses=[]
24+
)
25+
self.session.add(acl_details)
26+
self.session.commit()
27+
except SQLAlchemyError:
28+
self.session.rollback()
29+
return False
30+
31+
return True
32+
1333
def has_access(self, username: str, ip_address: str) -> bool:
1434
"""
1535
Checks whether a user has access to the provided the list of ip addresses.
@@ -43,18 +63,29 @@ def grant_access(self, username: str, ip_addresses: List[str]) -> bool:
4363

4464
acl_details.ip_addresses += ip_addresses
4565
acl_details.ip_addresses = list(set(acl_details.ip_addresses))
66+
67+
self.session.query(AccessControl).filter(
68+
AccessControl.username == username
69+
).update({"ip_addresses": acl_details.ip_addresses})
70+
4671
self.session.commit()
47-
except [AttributeError, SQLAlchemyError]:
72+
except AttributeError:
73+
return False
74+
except SQLAlchemyError:
75+
self.session.rollback()
4876
return False
4977

5078
return True
5179

52-
def remove_access(self, username: str, ip_addresses: List[str]) -> bool:
80+
def revoke_access(
81+
self, username: str, ip_addresses: List[str], revoke_all: bool = False
82+
) -> bool:
5383
"""
5484
Updates user access.
5585
5686
:param username:
5787
:param ip_addresses:
88+
:param revoke_all:.
5889
:return: booleans value for success/failure.
5990
"""
6091

@@ -63,11 +94,26 @@ def remove_access(self, username: str, ip_addresses: List[str]) -> bool:
6394
AccessControl.username == username
6495
).first()
6596

66-
for ip in ip_addresses:
67-
acl_details.ip_addresses.remove(ip)
97+
if not revoke_all:
98+
for ip in ip_addresses:
99+
try:
100+
acl_details.ip_addresses.remove(ip)
101+
except ValueError:
102+
continue
103+
104+
self.session.query(AccessControl).filter(
105+
AccessControl.username == username
106+
).update({"ip_addresses": acl_details.ip_addresses})
107+
else:
108+
self.session.query(AccessControl).filter(
109+
AccessControl.username == username
110+
).update({"ip_addresses": []})
68111

69112
self.session.commit()
70-
except [AttributeError, SQLAlchemyError]:
113+
except AttributeError:
114+
return False
115+
except SQLAlchemyError:
116+
self.session.rollback()
71117
return False
72118

73119
return True
@@ -85,5 +131,5 @@ def get_all_ips(self, username: str) -> List[str]:
85131
AccessControl.username == username
86132
).first()
87133
return acl_details.ip_addresses
88-
except [AttributeError, SQLAlchemyError]:
134+
except (AttributeError, SQLAlchemyError):
89135
return []

ssh_manager_backend/app/models/keys.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,30 @@ def exists(self, key_name: str) -> bool:
2121
return self.session.query(Key).filter(Key.name == key_name).first() is not None
2222

2323
def create(
24-
self, name: str, encrypted_key: bytes, key_hash: str, user: User
24+
self, name: str, encrypted_key: bytes, key_hash: str, user_id: int
2525
) -> bool:
2626
"""
2727
Creates a key in database.
2828
2929
:param name:
3030
:param encrypted_key:
3131
:param key_hash:
32-
:param user:
32+
:param user_id:
3333
:return: Boolean value indicating success/failure.
3434
"""
3535

3636
try:
3737
key: Key = Key(
38-
name=name, encrypted_key=encrypted_key, key_hash=key_hash, user=user
38+
name=name,
39+
encrypted_key=encrypted_key,
40+
key_hash=key_hash,
41+
user_id=user_id,
3942
)
4043

4144
self.session.add(key)
4245
self.session.commit()
4346
except SQLAlchemyError:
47+
self.session.rollback()
4448
return False
4549

4650
return True

ssh_manager_backend/app/models/keys_mapping.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,21 @@ def exists(self, ip_address: str) -> bool:
2525
is not None
2626
)
2727

28-
def create(self, ip_address: str, key_name: str, key: Key) -> bool:
28+
def create(self, ip_address: str, key_name: str) -> bool:
2929
"""
3030
Creates a key mapping in db.
3131
3232
:param ip_address:
3333
:param key_name:
34-
:param key: Key object
3534
:return: Boolean value indicating success/failure.
3635
"""
3736

3837
try:
39-
key_mapping = KeyMapping(key_name=key_name, ip_address=ip_address, key=key)
38+
key_mapping = KeyMapping(key_name=key_name, ip_address=ip_address)
4039
self.session.add(key_mapping)
4140
self.session.commit()
4241
except SQLAlchemyError:
42+
self.session.rollback()
4343
return False
4444

4545
return True

ssh_manager_backend/app/models/user.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def create(
9494
self.session.add(user)
9595
self.session.commit()
9696
except SQLAlchemyError:
97+
self.session.rollback()
9798
return False
9899

99100
return True
@@ -124,3 +125,23 @@ def get_user(self, username: str) -> Union[None, User]:
124125
"""
125126

126127
return self.session.query(User).filter(User.username == username).first()
128+
129+
def destroy_user(self, username: str) -> bool:
130+
"""
131+
Deletes a user.
132+
133+
:param username:
134+
:return: Boolean value indicating success/failure.
135+
"""
136+
137+
try:
138+
user: User = self.session.query(User).filter(
139+
User.username == username
140+
).first()
141+
self.session.delete(user)
142+
self.session.commit()
143+
except SQLAlchemyError:
144+
self.session.rollback()
145+
return False
146+
147+
return True

ssh_manager_backend/db/database.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
from sqlalchemy.ext.declarative import declarative_base
33
from sqlalchemy.orm import scoped_session, sessionmaker
44

5-
DB_URI = "postgresql+psycopg2://postgres:ssh_manager@localhost/ssh_manager_dev"
6-
engine = create_engine(DB_URI, convert_unicode=True, echo=True)
5+
DB_URI = (
6+
"postgresql+psycopg2://postgres:vP28ObNJLhb5qFDe@35.222.241.198/ssh_manager_test"
7+
# "postgresql+psycopg2://ssh_manager:pass@localhost/ssh_manager_dev"
8+
)
9+
engine = create_engine(DB_URI, echo=True)
710
db_session = scoped_session(
811
sessionmaker(autocommit=False, autoflush=False, bind=engine)
912
)

ssh_manager_backend/db/schema.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlalchemy import ARRAY, Binary, Boolean, Column, ForeignKey, Integer, String
1+
from sqlalchemy import ARRAY, Boolean, Column, ForeignKey, Integer, LargeBinary, String
22
from sqlalchemy.orm import relationship
33

44
from ssh_manager_backend.db.database import Base
@@ -12,12 +12,12 @@ class User(Base):
1212
username = Column(String, unique=True)
1313
password = Column(String)
1414
admin = Column(Boolean)
15-
encrypted_dek = Column(Binary, unique=True)
16-
iv_for_dek = Column(Binary, unique=True)
17-
salt_for_dek = Column(Binary, unique=True)
18-
iv_for_kek = Column(Binary, unique=True)
19-
salt_for_kek = Column(Binary, unique=True)
20-
salt_for_password = Column(Binary, unique=True)
15+
encrypted_dek = Column(LargeBinary, unique=True)
16+
iv_for_dek = Column(LargeBinary, unique=True)
17+
salt_for_dek = Column(LargeBinary, unique=True)
18+
iv_for_kek = Column(LargeBinary, unique=True)
19+
salt_for_kek = Column(LargeBinary, unique=True)
20+
salt_for_password = Column(LargeBinary, unique=True)
2121
keys = relationship("Key", backref="users")
2222
access_control = relationship("AccessControl", backref="users")
2323

@@ -34,7 +34,7 @@ class Key(Base):
3434

3535
id = Column(Integer, primary_key=True)
3636
name = Column(String, unique=True)
37-
encrypted_key = Column(Binary, unique=True)
37+
encrypted_key = Column(LargeBinary, unique=True)
3838
key_hash = Column(String, unique=True)
3939
user_id = Column(Integer, ForeignKey("users.id"))
4040
user = relationship("User")

tests/models/acl_model_test.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from typing import List
2+
3+
import pytest
4+
5+
from ssh_manager_backend.app.models.access_control import AccessControlModel
6+
from ssh_manager_backend.app.models.user import UserModel
7+
from tests.test_ssh_manager_backend import db_cleanup
8+
9+
10+
class TestAccessControlModel:
11+
@pytest.fixture
12+
def cleanup(self):
13+
yield
14+
db_cleanup()
15+
16+
def test_create(self):
17+
acl: AccessControlModel = AccessControlModel()
18+
user: UserModel = UserModel()
19+
20+
name: str = "test_user"
21+
username = "test_username"
22+
password: str = "test_password"
23+
admin: bool = False
24+
encrypted_dek: bytes = b"test_encrypted_dek"
25+
iv_for_dek: bytes = b"test_iv_for_dek"
26+
salt_for_dek: bytes = b"test_salt_for_dek"
27+
iv_for_kek: bytes = b"test_iv_for_kek"
28+
salt_for_kek: bytes = b"test_salt_for_kek"
29+
salt_for_password: bytes = b"test_salt_for_password"
30+
31+
assert acl.create(username=username) is False
32+
33+
assert (
34+
user.create(
35+
name=name,
36+
username=username,
37+
password=password,
38+
admin=admin,
39+
encrypted_dek=encrypted_dek,
40+
iv_for_dek=iv_for_dek,
41+
salt_for_dek=salt_for_dek,
42+
iv_for_kek=iv_for_kek,
43+
salt_for_kek=salt_for_kek,
44+
salt_for_password=salt_for_password,
45+
)
46+
is True
47+
)
48+
49+
assert acl.create(username=username) is True
50+
51+
def test_grant_access(self):
52+
acl: AccessControlModel = AccessControlModel()
53+
username: str = "test_username"
54+
ip_addresses: List[str] = ["1.1.1.1", "1.0.0.1"]
55+
56+
assert acl.grant_access(username=username, ip_addresses=ip_addresses) is True
57+
58+
assert (
59+
acl.grant_access(
60+
username="non_existent_username", ip_addresses=ip_addresses
61+
)
62+
is False
63+
)
64+
65+
assert sorted(acl.get_all_ips(username=username)) == sorted(ip_addresses)
66+
67+
def test_revoke_access(self, cleanup):
68+
acl: AccessControlModel = AccessControlModel()
69+
username: str = "test_username"
70+
ip_addresses: List[str] = ["1.1.1.1", "1.0.0.1"]
71+
72+
assert (
73+
acl.revoke_access(username=username, ip_addresses=[ip_addresses[0]]) is True
74+
)
75+
76+
assert (
77+
acl.revoke_access(username=username, ip_addresses=["non_existent_ip"])
78+
is True
79+
)
80+
81+
assert (
82+
acl.revoke_access(
83+
username="non_existent_username", ip_addresses=ip_addresses
84+
)
85+
is False
86+
)
87+
88+
assert acl.get_all_ips(username=username) == [ip_addresses[1]]

0 commit comments

Comments
 (0)