Skip to content

Commit

Permalink
Only invoke broker for selected flows (grants)
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Nov 2, 2023
1 parent 88c4bf8 commit f809634
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 15 deletions.
48 changes: 44 additions & 4 deletions msal/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@
Usage 1: Run it on the fly.
python -m msal
Note: We choose to not define a console script to avoid name conflict.
Usage 2: Build an all-in-one executable file for bug bash.
shiv -e msal.__main__._main -o msaltest-on-os-name.pyz .
Note: We choose to not define a console script to avoid name conflict.
"""
import base64, getpass, json, logging, sys, msal
import base64, getpass, json, logging, sys, os, atexit, msal

_token_cache_filename = "msal_cache.bin"
global_cache = msal.SerializableTokenCache()
atexit.register(lambda:
open(_token_cache_filename, "w").write(global_cache.serialize())
# Hint: The following optional line persists only when state changed
if global_cache.has_state_changed else None
)

_AZURE_CLI = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
_VISUAL_STUDIO = "04f0c124-f2bc-4f59-8241-bf6df9866bbd"
Expand Down Expand Up @@ -66,7 +74,7 @@ def _select_account(app):
if accounts:
return _select_options(
accounts,
option_renderer=lambda a: a["username"],
option_renderer=lambda a: "{}, came from {}".format(a["username"], a["account_source"]),
header="Account(s) already signed in inside MSAL Python:",
)
else:
Expand All @@ -76,7 +84,7 @@ def _acquire_token_silent(app):
"""acquire_token_silent() - with an account already signed into MSAL Python."""
account = _select_account(app)
if account:
print_json(app.acquire_token_silent(
print_json(app.acquire_token_silent_with_error(
_input_scopes(),
account=account,
force_refresh=_input_boolean("Bypass MSAL Python's token cache?"),
Expand Down Expand Up @@ -122,6 +130,15 @@ def _acquire_token_by_username_password(app):
print_json(app.acquire_token_by_username_password(
_input("username: "), getpass.getpass("password: "), scopes=_input_scopes()))

def _acquire_token_by_device_flow(app):
"""acquire_token_by_device_flow() - Note that this one does not go through broker"""
flow = app.initiate_device_flow(scopes=_input_scopes())
print(flow["message"])
sys.stdout.flush() # Some terminal needs this to ensure the message is shown
input("After you completed the step above, press ENTER in this console to continue...")
result = app.acquire_token_by_device_flow(flow) # By default it will block
print_json(result)

_JWK1 = """{"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}"""
_SSH_CERT_DATA = {"token_type": "ssh-cert", "key_id": "key1", "req_cnf": _JWK1}
_SSH_CERT_SCOPE = ["https://pas.windows.net/CheckMyAccess/Linux/.default"]
Expand Down Expand Up @@ -182,6 +199,27 @@ def _exit(app):

def _main():
print("Welcome to the Msal Python {} Tester (Experimental)\n".format(msal.__version__))
cache_choice = _select_options([
{
"choice": "empty",
"desc": "Start with an empty token cache. Suitable for one-off tests.",
},
{
"choice": "reuse",
"desc": "Reuse the previous token cache {} (if any) "
"which was created during last test app exit. "
"Useful for testing acquire_token_silent() repeatedly".format(
_token_cache_filename),
},
],
option_renderer=lambda o: o["desc"],
header="What token cache state do you want to begin with?",
accept_nonempty_string=False)
if cache_choice["choice"] == "reuse" and os.path.exists(_token_cache_filename):
try:
global_cache.deserialize(open(_token_cache_filename, "r").read())
except IOError:
pass # Use empty token cache
chosen_app = _select_options([
{"client_id": _AZURE_CLI, "name": "Azure CLI (Correctly configured for MSA-PT)"},
{"client_id": _VISUAL_STUDIO, "name": "Visual Studio (Correctly configured for MSA-PT)"},
Expand All @@ -207,6 +245,7 @@ def _main():
),
enable_broker_on_windows=enable_broker,
enable_pii_log=enable_pii_log,
token_cache=global_cache,
)
if enable_debug_log:
logging.basicConfig(level=logging.DEBUG)
Expand All @@ -215,6 +254,7 @@ def _main():
_acquire_token_silent,
_acquire_token_interactive,
_acquire_token_by_username_password,
_acquire_token_by_device_flow,
_acquire_ssh_cert_silently,
_acquire_ssh_cert_interactive,
_acquire_pop_token_interactive,
Expand Down
17 changes: 13 additions & 4 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .mex import send_request as mex_send_request
from .wstrust_request import send_request as wst_send_request
from .wstrust_response import *
from .token_cache import TokenCache, _get_username
from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER
import msal.telemetry
from .region import _detect_region
from .throttled_http_client import ThrottledHttpClient
Expand Down Expand Up @@ -1104,6 +1104,7 @@ def _find_msal_accounts(self, environment):
"home_account_id": a.get("home_account_id"),
"environment": a.get("environment"),
"username": a.get("username"),
"account_source": a.get("account_source"),

# The following fields for backward compatibility, for now
"authority_type": a.get("authority_type"),
Expand Down Expand Up @@ -1398,7 +1399,10 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
if account and account.get("authority_type") == _AUTHORITY_TYPE_CLOUDSHELL:
return self._acquire_token_by_cloud_shell(scopes, data=data)

if self._enable_broker and account is not None:
if self._enable_broker and account and account.get("account_source") in (
_GRANT_TYPE_BROKER, # Broker successfully established this account previously.
None, # Unknown data from older MSAL. Broker might still work.
):
from .broker import _acquire_token_silently
response = _acquire_token_silently(
"https://{}/{}".format(self.authority.instance, self.authority.tenant),
Expand All @@ -1409,8 +1413,12 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
self._client_capabilities, claims_challenge),
correlation_id=correlation_id,
**data)
if response: # The broker provided a decisive outcome, so we use it
return self._process_broker_response(response, scopes, data)
if response: # Broker provides a decisive outcome
account_was_established_by_broker = account.get(
"account_source") == _GRANT_TYPE_BROKER
broker_attempt_succeeded_just_now = "error" not in response
if account_was_established_by_broker or broker_attempt_succeeded_just_now:
return self._process_broker_response(response, scopes, data)

if account:
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
Expand Down Expand Up @@ -1441,6 +1449,7 @@ def _process_broker_response(self, response, scopes, data):
response=response,
data=data,
_account_id=response["_account_id"],
grant_type=_GRANT_TYPE_BROKER, # A pseudo grant type for TokenCache to mark account_source as broker
))
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_BROKER
return _clean_up(response)
Expand Down
16 changes: 9 additions & 7 deletions msal/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,22 @@ def _convert_error(error, client_id):


def _read_account_by_id(account_id, correlation_id):
"""Return an instance of MSALRuntimeError or MSALRuntimeAccount, or None"""
"""Return an instance of MSALRuntimeAccount, or log error and return None"""
callback_data = _CallbackData()
pymsalruntime.read_account_by_id(
account_id,
correlation_id,
lambda result, callback_data=callback_data: callback_data.complete(result)
)
callback_data.signal.wait()
return (callback_data.result.get_error() or callback_data.result.get_account()
or None) # None happens when the account was not created by broker
error = callback_data.result.get_error()
if error:
logger.debug("read_account_by_id() error: %s", _convert_error(error, None))
return None
account = callback_data.result.get_account()
if account:
return account
return None # None happens when the account was not created by broker


def _convert_result(result, client_id, expected_token_type=None): # Mimic an on-the-wire response from AAD
Expand Down Expand Up @@ -196,8 +202,6 @@ def _acquire_token_silently(
# acquireTokenSilently is expected to fail. - Sam Wilson
correlation_id = correlation_id or _get_new_correlation_id()
account = _read_account_by_id(account_id, correlation_id)
if isinstance(account, pymsalruntime.MSALRuntimeError):
return _convert_error(account, client_id)
if account is None:
return
params = pymsalruntime.MSALRuntimeAuthParameters(client_id, authority)
Expand All @@ -221,8 +225,6 @@ def _acquire_token_silently(
def _signout_silently(client_id, account_id, correlation_id=None):
correlation_id = correlation_id or _get_new_correlation_id()
account = _read_account_by_id(account_id, correlation_id)
if isinstance(account, pymsalruntime.MSALRuntimeError):
return _convert_error(account, client_id)
if account is None:
return
callback_data = _CallbackData()
Expand Down
7 changes: 7 additions & 0 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

from .authority import canonicalize
from .oauth2cli.oidc import decode_part, decode_id_token
from .oauth2cli.oauth2 import Client


logger = logging.getLogger(__name__)
_GRANT_TYPE_BROKER = "broker"

def is_subdict_of(small, big):
return dict(big, **small) == big
Expand Down Expand Up @@ -210,6 +212,11 @@ def __add(self, event, now=None):
else self.AuthorityType.MSSTS),
# "client_info": response.get("client_info"), # Optional
}
grant_types_that_establish_an_account = (
_GRANT_TYPE_BROKER, "authorization_code", "password",
Client.DEVICE_FLOW["GRANT_TYPE"])
if event.get("grant_type") in grant_types_that_establish_an_account:
account["account_source"] = event["grant_type"]
self.modify(self.CredentialType.ACCOUNT, account, account)

if id_token:
Expand Down

0 comments on commit f809634

Please sign in to comment.