Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

acquire_token_silent() shall not invoke broker if the account was not established by broker #569

Merged
merged 1 commit into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions msal/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@

Usage 1: Run it on the fly.
python -m msal
Note: We choose to not define a console script to avoid name conflict.

Usage 2: Build an all-in-one executable file for bug bash.
shiv -e msal.__main__._main -o msaltest-on-os-name.pyz .
Note: We choose to not define a console script to avoid name conflict.
"""
import base64, getpass, json, logging, sys, msal
import base64, getpass, json, logging, sys, os, atexit, msal

_token_cache_filename = "msal_cache.bin"
global_cache = msal.SerializableTokenCache()
atexit.register(lambda:
open(_token_cache_filename, "w").write(global_cache.serialize())
# Hint: The following optional line persists only when state changed
if global_cache.has_state_changed else None
)

_AZURE_CLI = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
_VISUAL_STUDIO = "04f0c124-f2bc-4f59-8241-bf6df9866bbd"
Expand Down Expand Up @@ -66,7 +74,7 @@ def _select_account(app):
if accounts:
return _select_options(
accounts,
option_renderer=lambda a: a["username"],
option_renderer=lambda a: "{}, came from {}".format(a["username"], a["account_source"]),
header="Account(s) already signed in inside MSAL Python:",
)
else:
Expand All @@ -76,7 +84,7 @@ def _acquire_token_silent(app):
"""acquire_token_silent() - with an account already signed into MSAL Python."""
account = _select_account(app)
if account:
print_json(app.acquire_token_silent(
print_json(app.acquire_token_silent_with_error(
_input_scopes(),
account=account,
force_refresh=_input_boolean("Bypass MSAL Python's token cache?"),
Expand Down Expand Up @@ -122,6 +130,15 @@ def _acquire_token_by_username_password(app):
print_json(app.acquire_token_by_username_password(
_input("username: "), getpass.getpass("password: "), scopes=_input_scopes()))

def _acquire_token_by_device_flow(app):
"""acquire_token_by_device_flow() - Note that this one does not go through broker"""
flow = app.initiate_device_flow(scopes=_input_scopes())
print(flow["message"])
sys.stdout.flush() # Some terminal needs this to ensure the message is shown
input("After you completed the step above, press ENTER in this console to continue...")
result = app.acquire_token_by_device_flow(flow) # By default it will block
print_json(result)

_JWK1 = """{"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}"""
_SSH_CERT_DATA = {"token_type": "ssh-cert", "key_id": "key1", "req_cnf": _JWK1}
_SSH_CERT_SCOPE = ["https://pas.windows.net/CheckMyAccess/Linux/.default"]
Expand Down Expand Up @@ -182,6 +199,27 @@ def _exit(app):

def _main():
print("Welcome to the Msal Python {} Tester (Experimental)\n".format(msal.__version__))
cache_choice = _select_options([
{
"choice": "empty",
"desc": "Start with an empty token cache. Suitable for one-off tests.",
},
{
"choice": "reuse",
"desc": "Reuse the previous token cache {} (if any) "
"which was created during last test app exit. "
"Useful for testing acquire_token_silent() repeatedly".format(
_token_cache_filename),
},
],
option_renderer=lambda o: o["desc"],
header="What token cache state do you want to begin with?",
accept_nonempty_string=False)
if cache_choice["choice"] == "reuse" and os.path.exists(_token_cache_filename):
try:
global_cache.deserialize(open(_token_cache_filename, "r").read())
except IOError:
pass # Use empty token cache
chosen_app = _select_options([
{"client_id": _AZURE_CLI, "name": "Azure CLI (Correctly configured for MSA-PT)"},
{"client_id": _VISUAL_STUDIO, "name": "Visual Studio (Correctly configured for MSA-PT)"},
Expand All @@ -207,6 +245,7 @@ def _main():
),
enable_broker_on_windows=enable_broker,
enable_pii_log=enable_pii_log,
token_cache=global_cache,
)
if enable_debug_log:
logging.basicConfig(level=logging.DEBUG)
Expand All @@ -215,6 +254,7 @@ def _main():
_acquire_token_silent,
_acquire_token_interactive,
_acquire_token_by_username_password,
_acquire_token_by_device_flow,
_acquire_ssh_cert_silently,
_acquire_ssh_cert_interactive,
_acquire_pop_token_interactive,
Expand Down
20 changes: 15 additions & 5 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .mex import send_request as mex_send_request
from .wstrust_request import send_request as wst_send_request
from .wstrust_response import *
from .token_cache import TokenCache, _get_username
from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER
import msal.telemetry
from .region import _detect_region
from .throttled_http_client import ThrottledHttpClient
Expand Down Expand Up @@ -1104,6 +1104,7 @@ def _find_msal_accounts(self, environment):
"home_account_id": a.get("home_account_id"),
"environment": a.get("environment"),
"username": a.get("username"),
"account_source": a.get("account_source"),

# The following fields for backward compatibility, for now
"authority_type": a.get("authority_type"),
Expand Down Expand Up @@ -1398,7 +1399,10 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
if account and account.get("authority_type") == _AUTHORITY_TYPE_CLOUDSHELL:
return self._acquire_token_by_cloud_shell(scopes, data=data)

if self._enable_broker and account is not None:
if self._enable_broker and account and account.get("account_source") in (
_GRANT_TYPE_BROKER, # Broker successfully established this account previously.
None, # Unknown data from older MSAL. Broker might still work.
):
rayluo marked this conversation as resolved.
Show resolved Hide resolved
from .broker import _acquire_token_silently
response = _acquire_token_silently(
"https://{}/{}".format(self.authority.instance, self.authority.tenant),
Expand All @@ -1409,8 +1413,12 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
self._client_capabilities, claims_challenge),
correlation_id=correlation_id,
**data)
if response: # The broker provided a decisive outcome, so we use it
return self._process_broker_response(response, scopes, data)
if response: # Broker provides a decisive outcome
account_was_established_by_broker = account.get(
"account_source") == _GRANT_TYPE_BROKER
broker_attempt_succeeded_just_now = "error" not in response
if account_was_established_by_broker or broker_attempt_succeeded_just_now:
return self._process_broker_response(response, scopes, data)

if account:
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
Expand Down Expand Up @@ -1441,6 +1449,8 @@ def _process_broker_response(self, response, scopes, data):
response=response,
data=data,
_account_id=response["_account_id"],
environment=self.authority.instance, # Be consistent with non-broker flows
grant_type=_GRANT_TYPE_BROKER, # A pseudo grant type for TokenCache to mark account_source as broker
))
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_BROKER
return _clean_up(response)
Expand Down Expand Up @@ -1628,7 +1638,7 @@ def acquire_token_by_username_password(
"""
claims = _merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)
if self._enable_broker:
if False: # Disabled, for now. It was if self._enable_broker:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI. In v1, we also expose broker for InteractiveBrowserCred, there is no support for username password for broker in Python identity too.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI v2: We reverted course again, back to the original.

CC: @jiasli

Note: New comments, if any, shall be posted into that new feature request 702.

from .broker import _signin_silently
response = _signin_silently(
"https://{}/{}".format(self.authority.instance, self.authority.tenant),
Expand Down
16 changes: 9 additions & 7 deletions msal/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,22 @@ def _convert_error(error, client_id):


def _read_account_by_id(account_id, correlation_id):
"""Return an instance of MSALRuntimeError or MSALRuntimeAccount, or None"""
"""Return an instance of MSALRuntimeAccount, or log error and return None"""
callback_data = _CallbackData()
pymsalruntime.read_account_by_id(
account_id,
correlation_id,
lambda result, callback_data=callback_data: callback_data.complete(result)
)
callback_data.signal.wait()
return (callback_data.result.get_error() or callback_data.result.get_account()
or None) # None happens when the account was not created by broker
error = callback_data.result.get_error()
if error:
logger.debug("read_account_by_id() error: %s", _convert_error(error, None))
return None
account = callback_data.result.get_account()
if account:
return account
return None # None happens when the account was not created by broker


def _convert_result(result, client_id, expected_token_type=None): # Mimic an on-the-wire response from AAD
Expand Down Expand Up @@ -196,8 +202,6 @@ def _acquire_token_silently(
# acquireTokenSilently is expected to fail. - Sam Wilson
correlation_id = correlation_id or _get_new_correlation_id()
account = _read_account_by_id(account_id, correlation_id)
if isinstance(account, pymsalruntime.MSALRuntimeError):
return _convert_error(account, client_id)
if account is None:
return
params = pymsalruntime.MSALRuntimeAuthParameters(client_id, authority)
Expand All @@ -221,8 +225,6 @@ def _acquire_token_silently(
def _signout_silently(client_id, account_id, correlation_id=None):
correlation_id = correlation_id or _get_new_correlation_id()
account = _read_account_by_id(account_id, correlation_id)
if isinstance(account, pymsalruntime.MSALRuntimeError):
return _convert_error(account, client_id)
if account is None:
return
callback_data = _CallbackData()
Expand Down
7 changes: 7 additions & 0 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

from .authority import canonicalize
from .oauth2cli.oidc import decode_part, decode_id_token
from .oauth2cli.oauth2 import Client


logger = logging.getLogger(__name__)
_GRANT_TYPE_BROKER = "broker"

def is_subdict_of(small, big):
return dict(big, **small) == big
Expand Down Expand Up @@ -210,6 +212,11 @@ def __add(self, event, now=None):
else self.AuthorityType.MSSTS),
# "client_info": response.get("client_info"), # Optional
}
grant_types_that_establish_an_account = (
_GRANT_TYPE_BROKER, "authorization_code", "password",
Client.DEVICE_FLOW["GRANT_TYPE"])
if event.get("grant_type") in grant_types_that_establish_an_account:
account["account_source"] = event["grant_type"]
self.modify(self.CredentialType.ACCOUNT, account, account)

if id_token:
Expand Down
79 changes: 79 additions & 0 deletions tests/test_account_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import json
try:
from unittest.mock import patch
except:
from mock import patch
try:
import pymsalruntime
broker_available = True
except ImportError:
broker_available = False
import msal
from tests import unittest
from tests.test_token_cache import build_response
from tests.http_client import MinimalResponse


SCOPE = "scope_foo"
TOKEN_RESPONSE = build_response(
access_token="at",
uid="uid", utid="utid", # So that it will create an account
scope=SCOPE, refresh_token="rt", # So that non-broker's acquire_token_silent() would work
)

def _mock_post(url, headers=None, *args, **kwargs):
return MinimalResponse(status_code=200, text=json.dumps(TOKEN_RESPONSE))

@unittest.skipUnless(broker_available, "These test cases need pip install msal[broker]")
@patch("msal.broker._acquire_token_silently", return_value=dict(
TOKEN_RESPONSE, _account_id="placeholder"))
@patch.object(msal.authority, "tenant_discovery", return_value={
"authorization_endpoint": "https://contoso.com/placeholder",
"token_endpoint": "https://contoso.com/placeholder",
}) # Otherwise it would fail on OIDC discovery
class TestAccountSourceBehavior(unittest.TestCase):

def test_device_flow_and_its_silent_call_should_bypass_broker(self, _, mocked_broker_ats):
app = msal.PublicClientApplication("client_id", enable_broker_on_windows=True)
result = app.acquire_token_by_device_flow({"device_code": "123"}, post=_mock_post)
self.assertEqual(result["token_source"], "identity_provider")

account = app.get_accounts()[0]
self.assertEqual(account["account_source"], "urn:ietf:params:oauth:grant-type:device_code")

result = app.acquire_token_silent_with_error(
[SCOPE], account, force_refresh=True, post=_mock_post)
mocked_broker_ats.assert_not_called()
self.assertEqual(result["token_source"], "identity_provider")

def test_ropc_flow_and_its_silent_call_should_bypass_broker(self, _, mocked_broker_ats):
app = msal.PublicClientApplication("client_id", enable_broker_on_windows=True)
with patch.object(app.authority, "user_realm_discovery", return_value={}):
result = app.acquire_token_by_username_password(
"username", "placeholder", [SCOPE], post=_mock_post)
self.assertEqual(result["token_source"], "identity_provider")

account = app.get_accounts()[0]
self.assertEqual(account["account_source"], "password")

result = app.acquire_token_silent_with_error(
[SCOPE], account, force_refresh=True, post=_mock_post)
mocked_broker_ats.assert_not_called()
self.assertEqual(result["token_source"], "identity_provider")

def test_interactive_flow_and_its_silent_call_should_invoke_broker(self, _, mocked_broker_ats):
app = msal.PublicClientApplication("client_id", enable_broker_on_windows=True)
with patch.object(app, "_acquire_token_interactive_via_broker", return_value=dict(
TOKEN_RESPONSE, _account_id="placeholder")):
result = app.acquire_token_interactive(
[SCOPE], parent_window_handle=app.CONSOLE_WINDOW_HANDLE)
self.assertEqual(result["token_source"], "broker")

account = app.get_accounts()[0]
self.assertEqual(account["account_source"], "broker")

result = app.acquire_token_silent_with_error(
[SCOPE], account, force_refresh=True, post=_mock_post)
mocked_broker_ats.assert_called_once()
self.assertEqual(result["token_source"], "broker")