Skip to content

Commit b488f16

Browse files
committed
POC of MSAL integration
1 parent 1e2fa15 commit b488f16

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

msal/application.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
12131213
refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge
12141214
assert refresh_reason, "It should have been established at this point"
12151215
try:
1216+
if sys.platform == "win32":
1217+
from .wam import _acquire_token_silently, _read_account_by_id
1218+
return _acquire_token_silently(
1219+
"https://{}/{}".format(self.authority.instance, self.authority.tenant), # TODO: What about B2C & ADFS?
1220+
self.client_id,
1221+
_read_account_by_id(account["local_account_id"]),
1222+
" ".join(scopes))
12161223
result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
12171224
authority, self._decorate_scope(scopes), account,
12181225
refresh_reason=refresh_reason, claims_challenge=claims_challenge,
@@ -1553,6 +1560,23 @@ def acquire_token_interactive(
15531560
and typically contains an "access_token" key.
15541561
- A dict containing an "error" key, when token refresh failed.
15551562
"""
1563+
if sys.platform == "win32":
1564+
from .wam import _signin_interactively
1565+
response = _signin_interactively(
1566+
"https://{}/{}".format(self.authority.instance, self.authority.tenant), # TODO: What about B2C & ADFS?
1567+
self.client_id,
1568+
" ".join(scopes),
1569+
login_hint=login_hint)
1570+
if response.get("error") != "TBD: Broker Unavailable": # TODO
1571+
self.token_cache.add(dict(
1572+
client_id=self.client_id,
1573+
scope=scopes,
1574+
token_endpoint=self.authority.token_endpoint,
1575+
response=response.copy(),
1576+
data=kwargs.get("data", {}),
1577+
))
1578+
return response
1579+
15561580
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
15571581
claims = _merge_claims_challenge_and_capabilities(
15581582
self._client_capabilities, claims_challenge)

msal/token_cache.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ def __add(self, event, now=None):
150150
id_token = response.get("id_token")
151151
id_token_claims = (
152152
decode_id_token(id_token, client_id=event["client_id"])
153-
if id_token else {})
153+
if id_token
154+
else response.get("id_token_claims", {})) # Mid-tier would provide id_token_claims
154155
client_info, home_account_id = self.__parse_account(response, id_token_claims)
155156

156157
target = ' '.join(event.get("scope") or []) # Per schema, we don't sort it

msal/wam.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import pymsalruntime # See https://github.com/AzureAD/microsoft-authentication-library-for-cpp/pull/2419/files#diff-d5ea5122ff04e14411a4f695895c923daba73c117d6c8ceb19c4fa3520c3c08a
1111
import win32gui # Came from package pywin32
1212

13-
1413
logger = logging.getLogger(__name__)
1514

1615

@@ -28,13 +27,14 @@ def _read_account_by_id(account_id):
2827
callback_data = _CallbackData()
2928
pymsalruntime.read_account_by_id(
3029
account_id,
30+
"correlation_id",
3131
lambda result, callback_data=callback_data: callback_data.complete(result)
3232
)
3333
callback_data.signal.wait()
3434
return callback_data.auth_result
3535

3636

37-
def _convert_result(result):
37+
def _convert_result(result): # Mimic an on-the-wire response from AAD
3838
error = result.get_error()
3939
if error:
4040
return {
@@ -43,13 +43,15 @@ def _convert_result(result):
4343
error.get_context(), # Available since pymsalruntime 0.0.4
4444
error.get_status(), error.get_error_code(), error.get_tag()),
4545
}
46+
id_token_claims = json.loads(result.get_id_token()) if result.get_id_token() else {}
47+
account = result.get_account()
48+
assert account.get_account_id() == id_token_claims.get("oid"), "Emperical observation" # TBD
4649
return {k: v for k, v in {
4750
"access_token": result.get_access_token(),
4851
"expires_in": result.get_access_token_expiry_time(),
4952
#"scope": result.get_granted_scopes(), # TODO
50-
"id_token_claims": json.loads(result.get_id_token())
51-
if result.get_id_token() else None,
52-
"account": result.get_account(),
53+
"id_token_claims": id_token_claims,
54+
"client_info": account.get_client_info(),
5355
}.items() if v}
5456

5557

0 commit comments

Comments
 (0)