Skip to content

Commit

Permalink
Merge pull request #441 from TeskaLabs/refactoring/remove-session-ada…
Browse files Browse the repository at this point in the history
…pter-service-dependency

Remove session adapter's dependency on session service
  • Loading branch information
byewokko authored Feb 13, 2025
2 parents e4be8c1 + 3e5b65f commit eb43e3c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 33 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## v25.05

### Pre-releases
- v25.05-alpha4
- v25.05-alpha3
- v25.05-alpha2
- v25.05-alpha1
Expand All @@ -14,6 +15,9 @@
- Configurable default tenant roles (#436, v25.05-alpha3)
- Provisioning service initialization uses system Session object (#439, v25.05-alpha2)

### Refactoring
- Remove session adapter's dependency on session service (#441, `v25.05-alpha4`)

---


Expand Down
26 changes: 2 additions & 24 deletions seacatauth/session/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@
from .. import AuditLogger
from ..authz.rbac.service import RBACService

#

L = logging.getLogger(__name__)

#


@dataclasses.dataclass
class SessionData:
Expand Down Expand Up @@ -172,9 +169,7 @@ class Batman:

EncryptedPrefix = b"$aescbc$"

def __init__(self, session_svc, session_dict):
self._decrypt_encrypted_identifiers(session_dict, session_svc)

def __init__(self, session_dict):
self.Session = self._deserialize_session_data(session_dict)
self.Id = self.Session.Id
self.SessionId = self.Session.Id
Expand Down Expand Up @@ -335,23 +330,6 @@ def has_global_resource_access(self, resource_id: str) -> bool:
and RBACService.has_resource_access(self.Authorization.Authz, None, {resource_id})
)

def _decrypt_encrypted_identifiers(self, session_dict, session_svc):
# Decrypt sensitive fields
for field in self.EncryptedIdentifierFields:
# BACK COMPAT: Handle nested dictionaries
obj = session_dict
keys = field.split(".")
for key in keys[:-1]:
if key not in obj:
break
obj = obj[key]
else:
# BACK COMPAT: Keep values without prefix raw
# TODO: Remove support once proper m2m tokens are in place
value = obj.get(keys[-1])
if value is not None and value.startswith(self.EncryptedPrefix):
obj[keys[-1]] = session_svc.aes_decrypt(value[len(self.EncryptedPrefix):])

@classmethod
def _deserialize_session_data(cls, session_dict):
return SessionData(
Expand Down Expand Up @@ -506,7 +484,7 @@ def rest_get(session_dict):

# TODO: Use ASAB Authorization, this is a temporary solution.
def build_system_session(session_service, session_id):
session = SessionAdapter(session_service, {
session = SessionAdapter({
SessionAdapter.FN.SessionId: session_id,
SessionAdapter.FN.Version: 0,
SessionAdapter.FN.CreatedAt: datetime.datetime.now(datetime.UTC),
Expand Down
2 changes: 1 addition & 1 deletion seacatauth/session/algorithmic.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def _build_anonymous_session(
SessionAdapter.FN.Authentication.IsAnonymous: True,
}
await self._add_session_authz(session_dict, client_dict["anonymous_cid"], scope)
return SessionAdapter(self, session_dict)
return SessionAdapter(session_dict)


async def _add_session_authz(self, session_dict: dict, credentials_id: str, scope: set):
Expand Down
37 changes: 29 additions & 8 deletions seacatauth/session/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,11 @@
cookie_session_builder
)

#

L = logging.getLogger(__name__)

#


class SessionService(asab.Service):

SessionCollection = "s"

def __init__(self, app, service_name="seacatauth.SessionService"):
Expand Down Expand Up @@ -170,7 +166,8 @@ async def _delete_expired_sessions(self):
# TODO: Improve performance - each self.delete(session_id) call searches for potential subsessions!
expired = []
async for session in self._iterate_raw(
query_filter={SessionAdapter.FN.Session.Expiration: {"$lt": datetime.datetime.now(datetime.timezone.utc)}}
query_filter={
SessionAdapter.FN.Session.Expiration: {"$lt": datetime.datetime.now(datetime.timezone.utc)}}
):
expired.append(session["_id"])

Expand Down Expand Up @@ -306,8 +303,10 @@ async def get_by(self, key: str, value):
if expires_at < datetime.datetime.now(datetime.timezone.utc):
raise exceptions.SessionNotFoundError("Session expired.", query={key: value})

session_dict = self._decrypt_encrypted_session_identifiers(session_dict)

try:
session = SessionAdapter(self, session_dict)
session = SessionAdapter(session_dict)
except Exception as e:
L.exception("Failed to create SessionAdapter from database object.", struct_data={
"sid": session_dict.get("_id"),
Expand All @@ -334,8 +333,10 @@ async def get(self, session_id):
if session_dict[SessionAdapter.FN.Session.Expiration] < datetime.datetime.now(datetime.timezone.utc):
raise exceptions.SessionNotFoundError("Session expired.", session_id=session_id)

session_dict = self._decrypt_encrypted_session_identifiers(session_dict)

try:
session = SessionAdapter(self, session_dict)
session = SessionAdapter(session_dict)
except Exception as e:
L.exception("Failed to create SessionAdapter from database object.", struct_data={
"sid": session_dict.get("_id"),
Expand Down Expand Up @@ -408,7 +409,7 @@ async def recursive_list(self, page: int = 0, limit: int = None, query_filter=No
count = await collection.count_documents(query_filter)
async for session_dict in self._iterate_raw(page, limit, query_filter):
try:
session = SessionAdapter(self, session_dict).rest_get()
session = SessionAdapter(session_dict).rest_get()
except Exception as e:
L.error("Failed to create SessionAdapter from database object: {}".format(e), struct_data={
"sid": session_dict.get("_id"),
Expand Down Expand Up @@ -530,6 +531,7 @@ async def delete(self, session_id):

# Delete all the session's tokens
await self.TokenService.delete_tokens_by_session_id(session_id)

# TODO: Publish pubsub message for session deletion


Expand Down Expand Up @@ -804,6 +806,7 @@ def aes_encrypt(self, raw_bytes: bytes):
encrypted = iv + (encryptor.update(token) + encryptor.finalize())
return encrypted


def aes_decrypt(self, encrypted_bytes: bytes):
algorithm = cryptography.hazmat.primitives.ciphers.algorithms.AES(self.AESKey)
iv, token = encrypted_bytes[:self.AESBlockSize], encrypted_bytes[self.AESBlockSize:]
Expand All @@ -812,3 +815,21 @@ def aes_decrypt(self, encrypted_bytes: bytes):
decryptor = cipher.decryptor()
raw = iv + (decryptor.update(token) + decryptor.finalize())
return raw


def _decrypt_encrypted_session_identifiers(self, session_dict: dict) -> dict:
for field in SessionAdapter.EncryptedIdentifierFields:
# BACK COMPAT: Handle nested dictionaries
obj = session_dict
keys = field.split(".")
for key in keys[:-1]:
if key not in obj:
break
obj = obj[key]
else:
# BACK COMPAT: Keep values without prefix raw
# TODO: Remove support once proper m2m tokens are in place
value = obj.get(keys[-1])
if value is not None and value.startswith(SessionAdapter.EncryptedPrefix):
obj[keys[-1]] = self.aes_decrypt(value[len(SessionAdapter.EncryptedPrefix):])
return session_dict

0 comments on commit eb43e3c

Please sign in to comment.