diff --git a/msal/wam.py b/msal/wam.py index 123b52bf..61265a8d 100644 --- a/msal/wam.py +++ b/msal/wam.py @@ -8,6 +8,7 @@ import logging import pymsalruntime # See https://github.com/AzureAD/microsoft-authentication-library-for-cpp/pull/2419/files#diff-d5ea5122ff04e14411a4f695895c923daba73c117d6c8ceb19c4fa3520c3c08a +import win32gui # Came from package pywin32 logger = logging.getLogger(__name__) @@ -32,6 +33,19 @@ def _read_account_by_id(account_id): callback_data.signal.wait() return callback_data.auth_result + +def _convert_result(result): + return {k: v for k, v in { + "error": result.get_error(), + "access_token": result.get_access_token(), + #"expires_in": result.get_access_token_expiry_time(), # TODO + #"scope": result.get_granted_scopes(), # TODO + "id_token_claims": json.loads(result.get_id_token()) + if result.get_id_token() else None, + "account": result.get_account(), + }.items() if v} + + def _signin_silently(authority, client_id, scope): params = pymsalruntime.MSALRuntimeAuthParameters(client_id, authority) params.set_requested_scopes(scope or "https://graph.microsoft.com/.default") @@ -43,13 +57,25 @@ def _signin_silently(authority, client_id, scope): callback_data.signal.wait() return callback_data.auth_result -def _signin_interactively(): +def _signin_interactively( + authority, client_id, scope, + login_hint=None, + window=None, + ): + params = pymsalruntime.MSALRuntimeAuthParameters(client_id, authority) + params.set_requested_scopes(scope or "https://graph.microsoft.com/.default") + params.set_redirect_uri( + "https://login.microsoftonline.com/common/oauth2/nativeclient") callback_data = _CallbackData() pymsalruntime.signin_interactively( - # TODO: Add other input parameters + window or win32gui.GetDesktopWindow(), # TODO: Remove win32gui + params, + "correlation", # TODO + login_hint or "", # Account hint lambda result, callback_data=callback_data: callback_data.complete(result)) callback_data.signal.wait() - return callback_data.auth_result + return _convert_result(callback_data.auth_result) + def _acquire_token_silently(authority, client_id, account, scope): params = pymsalruntime.MSALRuntimeAuthParameters(client_id, authority) @@ -60,18 +86,10 @@ def _acquire_token_silently(authority, client_id, account, scope): "correlation", # TODO lambda result, callback_data=callback_data: callback_data.complete(result)) callback_data.signal.wait() - result = callback_data.auth_result - return {k: v for k, v in { - "error": result.get_error(), - "access_token": result.get_access_token(), - #"expires_in": result.get_access_token_expiry_time(), # TODO - #"scope": result.get_granted_scopes(), # TODO - "id_token_claims": json.loads(result.get_id_token()) - if result.get_id_token() else None, - "account": result.get_account(), - }.items() if v} + return _convert_result(callback_data.auth_result) + -def _acquire_token_interactive( +def _acquire_token_interactively( authority, client_id, account, @@ -92,7 +110,8 @@ def _acquire_token_interactive( params.set_claims(claims_challenge) # TODO: Wire up other input parameters too callback_data = _CallbackData() - pymsalruntime.signin_interactively( + pymsalruntime.acquire_token_interactively( + window, # TODO params, "correlation", # TODO account, @@ -105,31 +124,14 @@ def acquire_token_interactive( authority, # type: str client_id, # type: str scopes, # type: list[str] + login_hint=None, **kwargs): """MSAL Python's acquire_token_interactive() will call this""" - scope = " ".join(scopes) - result = _signin_silently(authority, client_id) - logger.debug("%s, %s, %s, %s, %s", client_id, scope, result, dir(result), result.get_error()) - if not result.get_account(): - result = _signin_interactively(authority, client_id) - if not result.get_account(): - return {"error": result.get_error()} # TODO - - result = _acquire_token_silently( - authority, client_id, account, scope, **kwargs) - if not result.get_access_token(): - result = _acquire_token_interactive( - authority, client_id, account, scope, **kwargs) - if not result.get_access_token(): - return {"error": result.get_error()} # TODO - # TODO: Also store the tokens and account into MSAL's token cache - return {k: v for k, v in { - "access_token": result.get_access_token(), - "token_type": "Bearer", # TODO: TBD - "expires_in": result.get_access_token_expiry_time(), - "id_token": result.get_id_token(), - "scope": result.get_granted_scopes(), - } if v is not None} + return _signin_interactively( + authority, + client_id, + " ".join(scopes), + login_hint=login_hint) def acquire_token_silent( diff --git a/tests/test_wam.py b/tests/test_wam.py index 6e53492e..bf3ed861 100644 --- a/tests/test_wam.py +++ b/tests/test_wam.py @@ -9,21 +9,17 @@ class TestWam(unittest.TestCase): client_id = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" # A well-known app - @unittest.skip("Not yet implemented") def test_acquire_token_interactive(self): - acquire_token_interactive( + result = acquire_token_interactive( "https://login.microsoftonline.com/common", - #"my_client_id", "26a7ee05-5602-4d76-a7ba-eae8b7b67941", - #["foo", "bar"], ["https://graph.microsoft.com/.default"], ) + self.assertIsNotNone(result.get("access_token")) def test_acquire_token_silent(self): result = acquire_token_silent( "https://login.microsoftonline.com/common", - #"my_client_id", - #self.client_id, "26a7ee05-5602-4d76-a7ba-eae8b7b67941", ["https://graph.microsoft.com/.default"], #{"some_sort_of_id": "placeholder"}, # TODO