@@ -10,6 +10,26 @@ class AccessControlModel:
1010 def __init__ (self ):
1111 self .session = db_session ()
1212
13+ def create (self , username : str ):
14+ """
15+ Creates an entry in the access_control table for the given user.
16+
17+ :param username
18+ :return: boolean value whether the entry is created or not.
19+ """
20+
21+ try :
22+ acl_details : AccessControl = AccessControl (
23+ username = username , ip_addresses = []
24+ )
25+ self .session .add (acl_details )
26+ self .session .commit ()
27+ except SQLAlchemyError :
28+ self .session .rollback ()
29+ return False
30+
31+ return True
32+
1333 def has_access (self , username : str , ip_address : str ) -> bool :
1434 """
1535 Checks whether a user has access to the provided the list of ip addresses.
@@ -43,18 +63,29 @@ def grant_access(self, username: str, ip_addresses: List[str]) -> bool:
4363
4464 acl_details .ip_addresses += ip_addresses
4565 acl_details .ip_addresses = list (set (acl_details .ip_addresses ))
66+
67+ self .session .query (AccessControl ).filter (
68+ AccessControl .username == username
69+ ).update ({"ip_addresses" : acl_details .ip_addresses })
70+
4671 self .session .commit ()
47- except [AttributeError , SQLAlchemyError ]:
72+ except AttributeError :
73+ return False
74+ except SQLAlchemyError :
75+ self .session .rollback ()
4876 return False
4977
5078 return True
5179
52- def remove_access (self , username : str , ip_addresses : List [str ]) -> bool :
80+ def revoke_access (
81+ self , username : str , ip_addresses : List [str ], revoke_all : bool = False
82+ ) -> bool :
5383 """
5484 Updates user access.
5585
5686 :param username:
5787 :param ip_addresses:
88+ :param revoke_all:.
5889 :return: booleans value for success/failure.
5990 """
6091
@@ -63,11 +94,26 @@ def remove_access(self, username: str, ip_addresses: List[str]) -> bool:
6394 AccessControl .username == username
6495 ).first ()
6596
66- for ip in ip_addresses :
67- acl_details .ip_addresses .remove (ip )
97+ if not revoke_all :
98+ for ip in ip_addresses :
99+ try :
100+ acl_details .ip_addresses .remove (ip )
101+ except ValueError :
102+ continue
103+
104+ self .session .query (AccessControl ).filter (
105+ AccessControl .username == username
106+ ).update ({"ip_addresses" : acl_details .ip_addresses })
107+ else :
108+ self .session .query (AccessControl ).filter (
109+ AccessControl .username == username
110+ ).update ({"ip_addresses" : []})
68111
69112 self .session .commit ()
70- except [AttributeError , SQLAlchemyError ]:
113+ except AttributeError :
114+ return False
115+ except SQLAlchemyError :
116+ self .session .rollback ()
71117 return False
72118
73119 return True
@@ -85,5 +131,5 @@ def get_all_ips(self, username: str) -> List[str]:
85131 AccessControl .username == username
86132 ).first ()
87133 return acl_details .ip_addresses
88- except [ AttributeError , SQLAlchemyError ] :
134+ except ( AttributeError , SQLAlchemyError ) :
89135 return []
0 commit comments