Skip to content

Commit

Permalink
Merge pull request #569 from AzureAD/device-flow-and-msal-runtime
Browse files Browse the repository at this point in the history
acquire_token_silent() shall not invoke broker if the account was not established by broker
  • Loading branch information
rayluo authored Nov 3, 2023
2 parents 88c4bf8 + 6973a99 commit 9c124b5
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 16 deletions.
48 changes: 44 additions & 4 deletions msal/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@
Usage 1: Run it on the fly.
python -m msal
Note: We choose to not define a console script to avoid name conflict.
Usage 2: Build an all-in-one executable file for bug bash.
shiv -e msal.__main__._main -o msaltest-on-os-name.pyz .
Note: We choose to not define a console script to avoid name conflict.
"""
import base64, getpass, json, logging, sys, msal
import base64, getpass, json, logging, sys, os, atexit, msal

_token_cache_filename = "msal_cache.bin"
global_cache = msal.SerializableTokenCache()
atexit.register(lambda:
open(_token_cache_filename, "w").write(global_cache.serialize())
# Hint: The following optional line persists only when state changed
if global_cache.has_state_changed else None
)

_AZURE_CLI = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
_VISUAL_STUDIO = "04f0c124-f2bc-4f59-8241-bf6df9866bbd"
Expand Down Expand Up @@ -66,7 +74,7 @@ def _select_account(app):
if accounts:
return _select_options(
accounts,
option_renderer=lambda a: a["username"],
option_renderer=lambda a: "{}, came from {}".format(a["username"], a["account_source"]),
header="Account(s) already signed in inside MSAL Python:",
)
else:
Expand All @@ -76,7 +84,7 @@ def _acquire_token_silent(app):
"""acquire_token_silent() - with an account already signed into MSAL Python."""
account = _select_account(app)
if account:
print_json(app.acquire_token_silent(
print_json(app.acquire_token_silent_with_error(
_input_scopes(),
account=account,
force_refresh=_input_boolean("Bypass MSAL Python's token cache?"),
Expand Down Expand Up @@ -122,6 +130,15 @@ def _acquire_token_by_username_password(app):
print_json(app.acquire_token_by_username_password(
_input("username: "), getpass.getpass("password: "), scopes=_input_scopes()))

def _acquire_token_by_device_flow(app):
"""acquire_token_by_device_flow() - Note that this one does not go through broker"""
flow = app.initiate_device_flow(scopes=_input_scopes())
print(flow["message"])
sys.stdout.flush() # Some terminal needs this to ensure the message is shown
input("After you completed the step above, press ENTER in this console to continue...")
result = app.acquire_token_by_device_flow(flow) # By default it will block
print_json(result)

_JWK1 = """{"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}"""
_SSH_CERT_DATA = {"token_type": "ssh-cert", "key_id": "key1", "req_cnf": _JWK1}
_SSH_CERT_SCOPE = ["https://pas.windows.net/CheckMyAccess/Linux/.default"]
Expand Down Expand Up @@ -182,6 +199,27 @@ def _exit(app):

def _main():
print("Welcome to the Msal Python {} Tester (Experimental)\n".format(msal.__version__))
cache_choice = _select_options([
{
"choice": "empty",
"desc": "Start with an empty token cache. Suitable for one-off tests.",
},
{
"choice": "reuse",
"desc": "Reuse the previous token cache {} (if any) "
"which was created during last test app exit. "
"Useful for testing acquire_token_silent() repeatedly".format(
_token_cache_filename),
},
],
option_renderer=lambda o: o["desc"],
header="What token cache state do you want to begin with?",
accept_nonempty_string=False)
if cache_choice["choice"] == "reuse" and os.path.exists(_token_cache_filename):
try:
global_cache.deserialize(open(_token_cache_filename, "r").read())
except IOError:
pass # Use empty token cache
chosen_app = _select_options([
{"client_id": _AZURE_CLI, "name": "Azure CLI (Correctly configured for MSA-PT)"},
{"client_id": _VISUAL_STUDIO, "name": "Visual Studio (Correctly configured for MSA-PT)"},
Expand All @@ -207,6 +245,7 @@ def _main():
),
enable_broker_on_windows=enable_broker,
enable_pii_log=enable_pii_log,
token_cache=global_cache,
)
if enable_debug_log:
logging.basicConfig(level=logging.DEBUG)
Expand All @@ -215,6 +254,7 @@ def _main():
_acquire_token_silent,
_acquire_token_interactive,
_acquire_token_by_username_password,
_acquire_token_by_device_flow,
_acquire_ssh_cert_silently,
_acquire_ssh_cert_interactive,
_acquire_pop_token_interactive,
Expand Down
20 changes: 15 additions & 5 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .mex import send_request as mex_send_request
from .wstrust_request import send_request as wst_send_request
from .wstrust_response import *
from .token_cache import TokenCache, _get_username
from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER
import msal.telemetry
from .region import _detect_region
from .throttled_http_client import ThrottledHttpClient
Expand Down Expand Up @@ -1104,6 +1104,7 @@ def _find_msal_accounts(self, environment):
"home_account_id": a.get("home_account_id"),
"environment": a.get("environment"),
"username": a.get("username"),
"account_source": a.get("account_source"),

# The following fields for backward compatibility, for now
"authority_type": a.get("authority_type"),
Expand Down Expand Up @@ -1398,7 +1399,10 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
if account and account.get("authority_type") == _AUTHORITY_TYPE_CLOUDSHELL:
return self._acquire_token_by_cloud_shell(scopes, data=data)

if self._enable_broker and account is not None:
if self._enable_broker and account and account.get("account_source") in (
_GRANT_TYPE_BROKER, # Broker successfully established this account previously.
None, # Unknown data from older MSAL. Broker might still work.
):
from .broker import _acquire_token_silently
response = _acquire_token_silently(
"https://{}/{}".format(self.authority.instance, self.authority.tenant),
Expand All @@ -1409,8 +1413,12 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
self._client_capabilities, claims_challenge),
correlation_id=correlation_id,
**data)
if response: # The broker provided a decisive outcome, so we use it
return self._process_broker_response(response, scopes, data)
if response: # Broker provides a decisive outcome
account_was_established_by_broker = account.get(
"account_source") == _GRANT_TYPE_BROKER
broker_attempt_succeeded_just_now = "error" not in response
if account_was_established_by_broker or broker_attempt_succeeded_just_now:
return self._process_broker_response(response, scopes, data)

if account:
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
Expand Down Expand Up @@ -1441,6 +1449,8 @@ def _process_broker_response(self, response, scopes, data):
response=response,
data=data,
_account_id=response["_account_id"],
environment=self.authority.instance, # Be consistent with non-broker flows
grant_type=_GRANT_TYPE_BROKER, # A pseudo grant type for TokenCache to mark account_source as broker
))
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_BROKER
return _clean_up(response)
Expand Down Expand Up @@ -1628,7 +1638,7 @@ def acquire_token_by_username_password(
"""
claims = _merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)
if self._enable_broker:
if False: # Disabled, for now. It was if self._enable_broker:
from .broker import _signin_silently
response = _signin_silently(
"https://{}/{}".format(self.authority.instance, self.authority.tenant),
Expand Down
16 changes: 9 additions & 7 deletions msal/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,22 @@ def _convert_error(error, client_id):


def _read_account_by_id(account_id, correlation_id):
"""Return an instance of MSALRuntimeError or MSALRuntimeAccount, or None"""
"""Return an instance of MSALRuntimeAccount, or log error and return None"""
callback_data = _CallbackData()
pymsalruntime.read_account_by_id(
account_id,
correlation_id,
lambda result, callback_data=callback_data: callback_data.complete(result)
)
callback_data.signal.wait()
return (callback_data.result.get_error() or callback_data.result.get_account()
or None) # None happens when the account was not created by broker
error = callback_data.result.get_error()
if error:
logger.debug("read_account_by_id() error: %s", _convert_error(error, None))
return None
account = callback_data.result.get_account()
if account:
return account
return None # None happens when the account was not created by broker


def _convert_result(result, client_id, expected_token_type=None): # Mimic an on-the-wire response from AAD
Expand Down Expand Up @@ -196,8 +202,6 @@ def _acquire_token_silently(
# acquireTokenSilently is expected to fail. - Sam Wilson
correlation_id = correlation_id or _get_new_correlation_id()
account = _read_account_by_id(account_id, correlation_id)
if isinstance(account, pymsalruntime.MSALRuntimeError):
return _convert_error(account, client_id)
if account is None:
return
params = pymsalruntime.MSALRuntimeAuthParameters(client_id, authority)
Expand All @@ -221,8 +225,6 @@ def _acquire_token_silently(
def _signout_silently(client_id, account_id, correlation_id=None):
correlation_id = correlation_id or _get_new_correlation_id()
account = _read_account_by_id(account_id, correlation_id)
if isinstance(account, pymsalruntime.MSALRuntimeError):
return _convert_error(account, client_id)
if account is None:
return
callback_data = _CallbackData()
Expand Down
7 changes: 7 additions & 0 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

from .authority import canonicalize
from .oauth2cli.oidc import decode_part, decode_id_token
from .oauth2cli.oauth2 import Client


logger = logging.getLogger(__name__)
_GRANT_TYPE_BROKER = "broker"

def is_subdict_of(small, big):
return dict(big, **small) == big
Expand Down Expand Up @@ -210,6 +212,11 @@ def __add(self, event, now=None):
else self.AuthorityType.MSSTS),
# "client_info": response.get("client_info"), # Optional
}
grant_types_that_establish_an_account = (
_GRANT_TYPE_BROKER, "authorization_code", "password",
Client.DEVICE_FLOW["GRANT_TYPE"])
if event.get("grant_type") in grant_types_that_establish_an_account:
account["account_source"] = event["grant_type"]
self.modify(self.CredentialType.ACCOUNT, account, account)

if id_token:
Expand Down
79 changes: 79 additions & 0 deletions tests/test_account_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import json
try:
from unittest.mock import patch
except:
from mock import patch
try:
import pymsalruntime
broker_available = True
except ImportError:
broker_available = False
import msal
from tests import unittest
from tests.test_token_cache import build_response
from tests.http_client import MinimalResponse


SCOPE = "scope_foo"
TOKEN_RESPONSE = build_response(
access_token="at",
uid="uid", utid="utid", # So that it will create an account
scope=SCOPE, refresh_token="rt", # So that non-broker's acquire_token_silent() would work
)

def _mock_post(url, headers=None, *args, **kwargs):
return MinimalResponse(status_code=200, text=json.dumps(TOKEN_RESPONSE))

@unittest.skipUnless(broker_available, "These test cases need pip install msal[broker]")
@patch("msal.broker._acquire_token_silently", return_value=dict(
TOKEN_RESPONSE, _account_id="placeholder"))
@patch.object(msal.authority, "tenant_discovery", return_value={
"authorization_endpoint": "https://contoso.com/placeholder",
"token_endpoint": "https://contoso.com/placeholder",
}) # Otherwise it would fail on OIDC discovery
class TestAccountSourceBehavior(unittest.TestCase):

def test_device_flow_and_its_silent_call_should_bypass_broker(self, _, mocked_broker_ats):
app = msal.PublicClientApplication("client_id", enable_broker_on_windows=True)
result = app.acquire_token_by_device_flow({"device_code": "123"}, post=_mock_post)
self.assertEqual(result["token_source"], "identity_provider")

account = app.get_accounts()[0]
self.assertEqual(account["account_source"], "urn:ietf:params:oauth:grant-type:device_code")

result = app.acquire_token_silent_with_error(
[SCOPE], account, force_refresh=True, post=_mock_post)
mocked_broker_ats.assert_not_called()
self.assertEqual(result["token_source"], "identity_provider")

def test_ropc_flow_and_its_silent_call_should_bypass_broker(self, _, mocked_broker_ats):
app = msal.PublicClientApplication("client_id", enable_broker_on_windows=True)
with patch.object(app.authority, "user_realm_discovery", return_value={}):
result = app.acquire_token_by_username_password(
"username", "placeholder", [SCOPE], post=_mock_post)
self.assertEqual(result["token_source"], "identity_provider")

account = app.get_accounts()[0]
self.assertEqual(account["account_source"], "password")

result = app.acquire_token_silent_with_error(
[SCOPE], account, force_refresh=True, post=_mock_post)
mocked_broker_ats.assert_not_called()
self.assertEqual(result["token_source"], "identity_provider")

def test_interactive_flow_and_its_silent_call_should_invoke_broker(self, _, mocked_broker_ats):
app = msal.PublicClientApplication("client_id", enable_broker_on_windows=True)
with patch.object(app, "_acquire_token_interactive_via_broker", return_value=dict(
TOKEN_RESPONSE, _account_id="placeholder")):
result = app.acquire_token_interactive(
[SCOPE], parent_window_handle=app.CONSOLE_WINDOW_HANDLE)
self.assertEqual(result["token_source"], "broker")

account = app.get_accounts()[0]
self.assertEqual(account["account_source"], "broker")

result = app.acquire_token_silent_with_error(
[SCOPE], account, force_refresh=True, post=_mock_post)
mocked_broker_ats.assert_called_once()
self.assertEqual(result["token_source"], "broker")

0 comments on commit 9c124b5

Please sign in to comment.