diff --git a/.gitignore b/.gitignore index 36b43713..58868119 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,8 @@ docs/_build/ # The test configuration file(s) could potentially contain credentials tests/config.json +# Token Cache files +msal_cache.bin .env .perf.baseline diff --git a/msal/__main__.py b/msal/__main__.py index 1c09f868..0c6c59f7 100644 --- a/msal/__main__.py +++ b/msal/__main__.py @@ -299,6 +299,7 @@ def _main(): authority=authority, instance_discovery=instance_discovery, enable_broker_on_windows=enable_broker, + enable_broker_on_mac=enable_broker, enable_pii_log=enable_pii_log, token_cache=global_cache, ) if not is_cca else msal.ConfidentialClientApplication( diff --git a/msal/application.py b/msal/application.py index b3c07a47..260d80e0 100644 --- a/msal/application.py +++ b/msal/application.py @@ -5,6 +5,7 @@ import sys import warnings from threading import Lock +from typing import Optional # Needed in Python 3.7 & 3.8 import os from .oauth2cli import Client, JwtAssertionCreator @@ -21,11 +22,16 @@ # The __init__.py will import this. Not the other way around. -__version__ = "1.30.0" # When releasing, also check and bump our dependencies's versions if needed +__version__ = "1.31.0" # When releasing, also check and bump our dependencies's versions if needed logger = logging.getLogger(__name__) _AUTHORITY_TYPE_CLOUDSHELL = "CLOUDSHELL" +def _init_broker(enable_pii_log): # Make it a function to allow mocking + from . 
import broker # Trigger Broker's initialization, lazily + if enable_pii_log: + broker._enable_pii_log() + def extract_certs(public_cert_content): # Parses raw public certificate file contents and returns a list of strings # Usage: headers = {"x5c": extract_certs(open("my_cert.pem").read())} @@ -189,6 +195,21 @@ def obtain_token_by_username_password(self, username, password, **kwargs): username, password, headers=headers, **kwargs) +def _msal_extension_check(): + # Can't run this in module or class level otherwise you'll get circular import error + try: + from msal_extensions import __version__ as v + major, minor, _ = v.split(".", maxsplit=3) + if not (int(major) >= 1 and int(minor) >= 2): + warnings.warn( + "Please upgrade msal-extensions. " + "Only msal-extensions 1.2+ can work with msal 1.30+") + except ImportError: + pass # The optional msal_extensions is not installed. Business as usual. + except ValueError: + logger.exception(f"msal_extensions version {v} not in major.minor.patch format") + + class ClientApplication(object): """You do not usually directly use this class. Use its subclasses instead: :class:`PublicClientApplication` and :class:`ConfidentialClientApplication`. @@ -205,6 +226,7 @@ class ClientApplication(object): REMOVE_ACCOUNT_ID = "903" ATTEMPT_REGION_DISCOVERY = True # "TryAutoDetect" + DISABLE_MSAL_FORCE_REGION = False # Used in azure_region to disable MSAL_FORCE_REGION behavior _TOKEN_SOURCE = "token_source" _TOKEN_SOURCE_IDP = "identity_provider" _TOKEN_SOURCE_CACHE = "cache" @@ -411,9 +433,11 @@ def __init__( (STS) what this client is capable for, so STS can decide to turn on certain features. For example, if client is capable to handle *claims challenge*, - STS can then issue CAE access tokens to resources - knowing when the resource emits *claims challenge* - the client will be capable to handle. 
+ STS may issue + `Continuous Access Evaluation (CAE) `_ + access tokens to resources, + knowing that when the resource emits a *claims challenge* + the client will be able to handle those challenges. Implementation details: Client capability is implemented using "claims" parameter on the wire, @@ -426,11 +450,14 @@ def __init__( Instructs MSAL to use the Entra regional token service. This legacy feature is only available to first-party applications. Only ``acquire_token_for_client()`` is supported. - Supports 3 values: + Supports 4 values: - ``azure_region=None`` - meaning no region is used. This is the default value. - ``azure_region="some_region"`` - meaning the specified region is used. - ``azure_region=True`` - meaning MSAL will try to auto-detect the region. This is not recommended. + 1. ``azure_region=None`` - This default value means no region is configured. + MSAL will use the region defined in env var ``MSAL_FORCE_REGION``. + 2. ``azure_region="some_region"`` - meaning the specified region is used. + 3. ``azure_region=True`` - meaning + MSAL will try to auto-detect the region. This is not recommended. + 4. ``azure_region=False`` - meaning MSAL will use no region. .. note:: Region auto-discovery has been tested on VMs and on Azure Functions. It is unreliable. 
@@ -608,7 +635,10 @@ def __init__( except ValueError: # Those are explicit authority validation errors raise except Exception: # The rest are typically connection errors - if validate_authority and azure_region and not oidc_authority: + if validate_authority and not oidc_authority and ( + azure_region # Opted in to use region + or (azure_region is None and os.getenv("MSAL_FORCE_REGION")) # Will use region + ): # Since caller opts in to use region, here we tolerate connection # errors happened during authority validation at non-region endpoint self.authority = Authority( @@ -628,6 +658,8 @@ def __init__( self.authority_groups = None self._telemetry_buffer = {} self._telemetry_lock = Lock() + _msal_extension_check() + def _decide_broker(self, allow_broker, enable_pii_log): is_confidential_app = self.client_credential or isinstance( @@ -638,20 +670,28 @@ def _decide_broker(self, allow_broker, enable_pii_log): if allow_broker: warnings.warn( "allow_broker is deprecated. " - "Please use PublicClientApplication(..., enable_broker_on_windows=True)", + "Please use PublicClientApplication(..., " + "enable_broker_on_windows=True, " + "enable_broker_on_mac=...)", DeprecationWarning) - self._enable_broker = self._enable_broker or ( + opted_in_for_broker = ( + self._enable_broker # True means Opted-in from PCA + or ( # When we started the broker project on Windows platform, # the allow_broker was meant to be cross-platform. Now we realize # that other platforms have different redirect_uri requirements, # so the old allow_broker is deprecated and will only for Windows. allow_broker and sys.platform == "win32") - if (self._enable_broker and not is_confidential_app - and not self.authority.is_adfs and not self.authority._is_b2c): + ) + self._enable_broker = ( # This same variable will also store the state + opted_in_for_broker + and not is_confidential_app + and not self.authority.is_adfs + and not self.authority._is_b2c + ) + if self._enable_broker: try: - from . 
import broker # Trigger Broker's initialization - if enable_pii_log: - broker._enable_pii_log() + _init_broker(enable_pii_log) except RuntimeError: self._enable_broker = False logger.exception( @@ -692,9 +732,11 @@ def _build_telemetry_context( self._telemetry_buffer, self._telemetry_lock, api_id, correlation_id=correlation_id, refresh_reason=refresh_reason) - def _get_regional_authority(self, central_authority): - if not self._region_configured: # User did not opt-in to ESTS-R + def _get_regional_authority(self, central_authority) -> Optional[Authority]: + if self._region_configured is False: # User opts out of ESTS-R return None # Short circuit to completely bypass region detection + if self._region_configured is None: # User did not make an ESTS-R choice + self._region_configured = os.getenv("MSAL_FORCE_REGION") or None self._region_detected = self._region_detected or _detect_region( self.http_client if self._region_configured is not None else None) if (self._region_configured != self.ATTEMPT_REGION_DISCOVERY @@ -1879,7 +1921,7 @@ def __init__(self, client_id, client_credential=None, **kwargs): .. note:: - You may set enable_broker_on_windows to True. + You may set enable_broker_on_windows and/or enable_broker_on_mac to True. **What is a broker, and why use it?** @@ -1905,9 +1947,11 @@ def __init__(self, client_id, client_credential=None, **kwargs): * ``ms-appx-web://Microsoft.AAD.BrokerPlugin/your_client_id`` if your app is expected to run on Windows 10+ + * ``msauth.com.msauth.unsignedapp://auth`` + if your app is expected to run on Mac 2. installed broker dependency, - e.g. ``pip install msal[broker]>=1.25,<2``. + e.g. ``pip install msal[broker]>=1.31,<2``. 3. tested with ``acquire_token_interactive()`` and ``acquire_token_silent()``. @@ -1939,12 +1983,21 @@ def __init__(self, client_id, client_credential=None, **kwargs): This parameter defaults to None, which means MSAL will not utilize a broker. New in MSAL Python 1.25.0. 
+ + :param boolean enable_broker_on_mac: + This setting is only effective if your app is running on Mac. + This parameter defaults to None, which means MSAL will not utilize a broker. + + New in MSAL Python 1.31.0. """ if client_credential is not None: raise ValueError("Public Client should not possess credentials") # Using kwargs notation for now. We will switch to keyword-only arguments. enable_broker_on_windows = kwargs.pop("enable_broker_on_windows", False) - self._enable_broker = enable_broker_on_windows and sys.platform == "win32" + enable_broker_on_mac = kwargs.pop("enable_broker_on_mac", False) + self._enable_broker = bool( + enable_broker_on_windows and sys.platform == "win32" + or enable_broker_on_mac and sys.platform == "darwin") super(PublicClientApplication, self).__init__( client_id, client_credential=None, **kwargs) @@ -2022,14 +2075,22 @@ def acquire_token_interactive( New in version 1.15. :param int parent_window_handle: - Required if your app is running on Windows and opted in to use broker. + OPTIONAL. + + * If your app does not opt in to use broker, + you do not need to provide a ``parent_window_handle`` here. + + * If your app opts in to use broker, + ``parent_window_handle`` is required. - If your app is a GUI app, - you are recommended to also provide its window handle, - so that the sign in UI window will properly pop up on top of your window. + - If your app is a GUI app running on Windows or Mac system, + you are required to also provide its window handle, + so that the sign-in window will pop up on top of your window. + - If your app is a console app running on Windows or Mac system, + you can use a placeholder + ``PublicClientApplication.CONSOLE_WINDOW_HANDLE``. - If your app is a console app (most Python scripts are console apps), - you can use a placeholder value ``msal.PublicClientApplication.CONSOLE_WINDOW_HANDLE``. + Most Python scripts are console apps. New in version 1.20.0. 
diff --git a/msal/broker.py b/msal/broker.py index 82bc3d87..775475a7 100644 --- a/msal/broker.py +++ b/msal/broker.py @@ -1,9 +1,9 @@ """This module is an adaptor to the underlying broker. It relies on PyMsalRuntime which is the package providing broker's functionality. """ -from threading import Event import json import logging +import sys import time import uuid @@ -23,7 +23,15 @@ except (ImportError, AttributeError): # AttributeError happens when a prior pymsalruntime uninstallation somehow leaved an empty folder behind # PyMsalRuntime currently supports these Windows versions, listed in this MSFT internal link # https://github.com/AzureAD/microsoft-authentication-library-for-cpp/pull/2406/files - raise ImportError('You need to install dependency by: pip install "msal[broker]>=1.20,<2"') + min_ver = { + "win32": "1.20", + "darwin": "1.31", + }.get(sys.platform) + if min_ver: + raise ImportError( + f'You must install dependency by: pip install "msal[broker]>={min_ver},<2"') + else: # Unsupported platform + raise ImportError("Dependency pymsalruntime unavailable on current platform") # It could throw RuntimeError when running on ancient versions of Windows @@ -35,14 +43,12 @@ class TokenTypeError(ValueError): pass -class _CallbackData: - def __init__(self): - self.signal = Event() - self.result = None - - def complete(self, result): - self.signal.set() - self.result = result +_redirect_uri_on_mac = "msauth.com.msauth.unsignedapp://auth" # Note: + # On Mac, the native Python has a team_id which links to bundle id + # com.apple.python3 however it won't give Python scripts better security. + # Besides, the homebrew-installed Pythons have no team_id + # so they have to use a generic placeholder anyway. + # The v-team chose to combine two situations into using same placeholder. def _convert_error(error, client_id): @@ -52,8 +58,9 @@ def _convert_error(error, client_id): or "AADSTS7000218" in context # This "request body must contain ... 
client_secret" is just a symptom of current app has no WAM redirect_uri ): raise RedirectUriError( # This would be seen by either the app developer or end user - "MsalRuntime won't work unless this one more redirect_uri is registered to current app: " - "ms-appx-web://Microsoft.AAD.BrokerPlugin/{}".format(client_id)) + "MsalRuntime needs the current app to register these redirect_uri " + "(1) ms-appx-web://Microsoft.AAD.BrokerPlugin/{} (2) {}".format( + client_id, _redirect_uri_on_mac)) # OTOH, AAD would emit other errors when other error handling branch was hit first, # so, the AADSTS50011/RedirectUriError is not guaranteed to happen. return { @@ -70,8 +77,8 @@ def _convert_error(error, client_id): def _read_account_by_id(account_id, correlation_id): - """Return an instance of MSALRuntimeAccount, or log error and return None""" - callback_data = _CallbackData() + """Return an instance of MSALRuntimeError or MSALRuntimeAccount, or None""" + callback_data = pymsalruntime.CallbackData() pymsalruntime.read_account_by_id( account_id, correlation_id, @@ -142,7 +149,7 @@ def _signin_silently( params.set_pop_params( auth_scheme._http_method, auth_scheme._url.netloc, auth_scheme._url.path, auth_scheme._nonce) - callback_data = _CallbackData() + callback_data = pymsalruntime.CallbackData() for k, v in kwargs.items(): # This can be used to support domain_hint, max_age, etc. 
if v is not None: params.set_additional_parameter(k, str(v)) @@ -169,9 +176,12 @@ def _signin_interactively( **kwargs): params = pymsalruntime.MSALRuntimeAuthParameters(client_id, authority) params.set_requested_scopes(scopes) - params.set_redirect_uri("https://login.microsoftonline.com/common/oauth2/nativeclient") - # This default redirect_uri value is not currently used by the broker + params.set_redirect_uri( + _redirect_uri_on_mac if sys.platform == "darwin" else + "https://login.microsoftonline.com/common/oauth2/nativeclient" + # This default redirect_uri value is not currently used by WAM # but it is required by the MSAL.cpp to be set to a non-empty valid URI. + ) if prompt: if prompt == "select_account": if login_hint: @@ -198,7 +208,7 @@ def _signin_interactively( params.set_additional_parameter(k, str(v)) if claims: params.set_decoded_claims(claims) - callback_data = _CallbackData() + callback_data = pymsalruntime.CallbackData(is_interactive=True) pymsalruntime.signin_interactively( parent_window_handle or pymsalruntime.get_console_window() or pymsalruntime.get_desktop_window(), # Since pymsalruntime 0.2+ params, @@ -231,7 +241,7 @@ def _acquire_token_silently( for k, v in kwargs.items(): # This can be used to support domain_hint, max_age, etc. 
if v is not None: params.set_additional_parameter(k, str(v)) - callback_data = _CallbackData() + callback_data = pymsalruntime.CallbackData() pymsalruntime.acquire_token_silently( params, correlation_id, @@ -247,7 +257,7 @@ def _signout_silently(client_id, account_id, correlation_id=None): account = _read_account_by_id(account_id, correlation_id) if account is None: return - callback_data = _CallbackData() + callback_data = pymsalruntime.CallbackData() pymsalruntime.signout_silently( # New in PyMsalRuntime 0.7 client_id, correlation_id, diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 755c9bd8..067f9ff1 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -10,7 +10,7 @@ import time from urllib.parse import urlparse # Python 3+ from collections import UserDict # Python 3+ -from typing import Union # Needed in Python 3.7 & 3.8 +from typing import Optional, Union # Needed in Python 3.7 & 3.8 from .token_cache import TokenCache from .individual_cache import _IndividualCache as IndividualCache from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser @@ -40,14 +40,15 @@ class ManagedIdentity(UserDict): _types_mapping = { # Maps type name in configuration to type name on wire CLIENT_ID: "client_id", - RESOURCE_ID: "mi_res_id", + RESOURCE_ID: "msi_res_id", # VM's IMDS prefers msi_res_id https://github.com/Azure/azure-rest-api-specs/blob/dba6ed1f03bda88ac6884c0a883246446cc72495/specification/imds/data-plane/Microsoft.InstanceMetadataService/stable/2018-10-01/imds.json#L233-L239 OBJECT_ID: "object_id", } @classmethod def is_managed_identity(cls, unknown): - return isinstance(unknown, ManagedIdentity) or ( - isinstance(unknown, dict) and cls.ID_TYPE in unknown) + return (isinstance(unknown, ManagedIdentity) + or cls.is_system_assigned(unknown) + or cls.is_user_assigned(unknown)) @classmethod def is_system_assigned(cls, unknown): @@ -133,6 +134,23 @@ class ManagedIdentityClient(object): It also provides token cache 
support. + .. admonition:: Special case when your local development wants to use a managed identity on Azure VM. + + By setting the environment variable ``MSAL_MANAGED_IDENTITY_ENDPOINT`` + you override the default identity URL used in MSAL's Azure VM managed identity + code path. + + This is useful during local development where it may be desirable to + utilise the credentials assigned to an actual VM instance via SSH tunnelling. + + For example, if you create your SSH tunnel this way (assuming your VM is on ``192.0.2.1``):: + + ssh -L 8000:169.254.169.254:80 192.0.2.1 + + Then your code could run locally using:: + + env MSAL_MANAGED_IDENTITY_ENDPOINT=http://localhost:8000/metadata/identity/oauth2/token python your_script.py + .. note:: Cloud Shell support is NOT implemented in this class. @@ -145,6 +163,9 @@ class ManagedIdentityClient(object): not a token with application permissions for an app. """ __instance, _tenant = None, "managed_identity" # Placeholders + _TOKEN_SOURCE = "token_source" + _TOKEN_SOURCE_IDP = "identity_provider" + _TOKEN_SOURCE_CACHE = "cache" def __init__( self, @@ -214,6 +235,9 @@ def __init__( ) token = client.acquire_token_for_client("resource") """ + if not ManagedIdentity.is_managed_identity(managed_identity): + raise ManagedIdentityError( + f"Incorrect managed_identity: {managed_identity}") self._managed_identity = managed_identity self._http_client = _ThrottledHttpClient( # This class only throttles excess token acquisition requests. @@ -237,12 +261,31 @@ def _get_instance(self): self.__instance = socket.getfqdn() # Moved from class definition to here return self.__instance - def acquire_token_for_client(self, *, resource): # We may support scope in the future + def acquire_token_for_client( + self, + *, + resource: str, # If/when we support scope, resource will become optional + claims_challenge: Optional[str] = None, + ): """Acquire token for the managed identity. The result will be automatically cached. 
Subsequent calls will automatically search from cache first. + :param resource: The resource for which the token is acquired. + + :param claims_challenge: + Optional. + It is a string representation of a JSON object + (which contains lists of claims being requested). + + The tenant admin may choose to revoke all Managed Identity tokens, + and then a *claims challenge* will be returned by the target resource, + as a `claims_challenge` directive in the `www-authenticate` header, + even if the app developer did not opt in for the "CP1" client capability. + Upon receiving a `claims_challenge`, MSAL will skip a token cache read, + and will attempt to acquire a new token. + .. note:: Known issue: When an Azure VM has only one user-assigned managed identity, @@ -255,8 +298,8 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the access_token_from_cache = None client_id_in_cache = self._managed_identity.get( ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") - if True: # Does not offer an "if not force_refresh" option, because - # there would be built-in token cache in the service side anyway + now = time.time() + if not claims_challenge: # Then attempt token cache search matches = self._token_cache.find( self._token_cache.CredentialType.ACCESS_TOKEN, target=[resource], @@ -267,7 +310,6 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the home_account_id=None, ), ) - now = time.time() for entry in matches: expires_in = int(entry["expires_on"]) - now if expires_in < 5*60: # Then consider it expired @@ -277,6 +319,7 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the "access_token": entry["secret"], "token_type": entry.get("token_type", "Bearer"), "expires_in": int(expires_in), # OAuth2 specs defines it as int + self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE, } if "refresh_on" in entry: access_token_from_cache["refresh_on"] = int(entry["refresh_on"]) @@ -300,6 +343,7 @@ def 
acquire_token_for_client(self, *, resource): # We may support scope in the )) if "refresh_in" in result: result["refresh_on"] = int(now + result["refresh_in"]) + result[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP if (result and "error" not in result) or (not access_token_from_cache): return result except: # The exact HTTP exception is transportation-layer dependent @@ -405,9 +449,9 @@ def _obtain_token(http_client, managed_identity, resource): return _obtain_token_on_azure_vm(http_client, managed_identity, resource) -def _adjust_param(params, managed_identity): +def _adjust_param(params, managed_identity, types_mapping=None): # Modify the params dict in place - id_name = ManagedIdentity._types_mapping.get( + id_name = (types_mapping or ManagedIdentity._types_mapping).get( managed_identity.get(ManagedIdentity.ID_TYPE)) if id_name: params[id_name] = managed_identity[ManagedIdentity.ID] @@ -421,7 +465,7 @@ def _obtain_token_on_azure_vm(http_client, managed_identity, resource): } _adjust_param(params, managed_identity) resp = http_client.get( - "http://169.254.169.254/metadata/identity/oauth2/token", + os.getenv('MSAL_MANAGED_IDENTITY_ENDPOINT', 'http://169.254.169.254/metadata/identity/oauth2/token'), params=params, headers={"Metadata": "true"}, ) @@ -454,7 +498,12 @@ def _obtain_token_on_app_service( "api-version": "2019-08-01", "resource": resource, } - _adjust_param(params, managed_identity) + _adjust_param(params, managed_identity, types_mapping={ + ManagedIdentity.CLIENT_ID: "client_id", + ManagedIdentity.RESOURCE_ID: "mi_res_id", # App Service's resource id uses "mi_res_id" + ManagedIdentity.OBJECT_ID: "object_id", + }) + resp = http_client.get( endpoint, params=params, diff --git a/msal/token_cache.py b/msal/token_cache.py index e554e118..66be5c9f 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -43,6 +43,8 @@ def __init__(self): self._lock = threading.RLock() self._cache = {} self.key_makers = { + # Note: We have changed token key format before 
when ordering scopes; + # changing token key won't result in cache miss. self.CredentialType.REFRESH_TOKEN: lambda home_account_id=None, environment=None, client_id=None, target=None, **ignored_payload_from_a_real_token: @@ -56,14 +58,18 @@ def __init__(self): ]).lower(), self.CredentialType.ACCESS_TOKEN: lambda home_account_id=None, environment=None, client_id=None, - realm=None, target=None, **ignored_payload_from_a_real_token: - "-".join([ + realm=None, target=None, + # Note: New field(s) can be added here + #key_id=None, + **ignored_payload_from_a_real_token: + "-".join([ # Note: Could use a hash here to shorten key length home_account_id or "", environment or "", self.CredentialType.ACCESS_TOKEN, client_id or "", realm or "", target or "", + #key_id or "", # So ATs of different key_id can coexist ]).lower(), self.CredentialType.ID_TOKEN: lambda home_account_id=None, environment=None, client_id=None, @@ -124,7 +130,7 @@ def _is_matching(entry: dict, query: dict, target_set: set = None) -> bool: target_set <= set(entry.get("target", "").split()) if target_set else True) - def search(self, credential_type, target=None, query=None): # O(n) generator + def search(self, credential_type, target=None, query=None, *, now=None): # O(n) generator """Returns a generator of matching entries. It is O(1) for AT hits, and O(n) for other types. @@ -150,21 +156,33 @@ def search(self, credential_type, target=None, query=None): # O(n) generator target_set = set(target) with self._lock: - # Since the target inside token cache key is (per schema) unsorted, - # there is no point to attempt an O(1) key-value search here. - # So we always do an O(n) in-memory search. + # O(n) search. The key is NOT used in search. + now = int(time.time() if now is None else now) + expired_access_tokens = [ + # Especially when/if we key ATs by ephemeral fields such as key_id, + # stale ATs keyed by an old key_id would stay forever. + # Here we collect them for their removal. 
+ ] for entry in self._cache.get(credential_type, {}).values(): + if ( # Automatically delete expired access tokens + credential_type == self.CredentialType.ACCESS_TOKEN + and int(entry["expires_on"]) < now + ): + expired_access_tokens.append(entry) # Can't delete them within current for-loop + continue if (entry != preferred_result # Avoid yielding the same entry twice and self._is_matching(entry, query, target_set=target_set) ): yield entry + for at in expired_access_tokens: + self.remove_at(at) - def find(self, credential_type, target=None, query=None): + def find(self, credential_type, target=None, query=None, *, now=None): """Equivalent to list(search(...)).""" warnings.warn( "Use list(search(...)) instead to explicitly get a list.", DeprecationWarning) - return list(self.search(credential_type, target=target, query=query)) + return list(self.search(credential_type, target=target, query=query, now=now)) def add(self, event, now=None): """Handle a token obtaining event, and add tokens into cache.""" @@ -249,8 +267,11 @@ def __add(self, event, now=None): "expires_on": str(now + expires_in), # Same here "extended_expires_on": str(now + ext_expires_in) # Same here } - if data.get("key_id"): # It happens in SSH-cert or POP scenario - at["key_id"] = data.get("key_id") + at.update({k: data[k] for k in data if k in { + # Also store extra data which we explicitly allow + # So that we won't accidentally store a user's password etc. 
+ "key_id", # It happens in SSH-cert or POP scenario + }}) if "refresh_in" in response: refresh_in = response["refresh_in"] # It is an integer at["refresh_on"] = str(now + refresh_in) # Schema wants a string diff --git a/sample/interactive_sample.py b/sample/interactive_sample.py index 3a361fcf..8c3f2df9 100644 --- a/sample/interactive_sample.py +++ b/sample/interactive_sample.py @@ -46,7 +46,8 @@ authority=os.getenv('AUTHORITY'), # For Entra ID or External ID oidc_authority=os.getenv('OIDC_AUTHORITY'), # For External ID with custom domain #enable_broker_on_windows=True, # Opted in. You will be guided to meet the prerequisites, if your app hasn't already - # See also: https://docs.microsoft.com/en-us/azure/active-directory/develop/scenario-desktop-acquire-token-wam#wam-value-proposition + #enable_broker_on_mac=True, # Opted in. You will be guided to meet the prerequisites, if your app hasn't already + token_cache=global_token_cache, # Let this app (re)use an existing token cache. # If absent, ClientApplication will create its own empty token cache ) diff --git a/setup.cfg b/setup.cfg index 490e3ab8..33ec3f06 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,9 +60,11 @@ broker = # The broker is defined as optional dependency, # so that downstream apps can opt in. The opt-in is needed, partially because # most existing MSAL Python apps do not have the redirect_uri needed by broker. 
- # MSAL Python uses a subset of API from PyMsalRuntime 0.13.0+, - # but we still bump the lower bound to 0.13.2+ for its important bugfix (https://github.com/AzureAD/microsoft-authentication-library-for-cpp/pull/3244) - pymsalruntime>=0.13.2,<0.17; python_version>='3.6' and platform_system=='Windows' + # + # We need pymsalruntime.CallbackData introduced in PyMsalRuntime 0.14 + pymsalruntime>=0.14,<0.18; python_version>='3.6' and platform_system=='Windows' + # On Mac, PyMsalRuntime 0.17+ is expected to support SSH cert and ROPC + pymsalruntime>=0.17,<0.18; python_version>='3.8' and platform_system=='Darwin' [options.packages.find] exclude = diff --git a/tests/broker-test.py b/tests/broker-test.py index 216d5256..cdcc4817 100644 --- a/tests/broker-test.py +++ b/tests/broker-test.py @@ -4,6 +4,16 @@ Each time a new PyMsalRuntime is going to be released, we can use this script to test it with a given version of MSAL Python. + +1. If you are on a modern Windows device, broker WAM is already built-in; + If you are on a mac device, install CP (Company Portal), login an account in CP and finish the MDM process. +2. For installing MSAL Python from its latest `dev` branch: + `pip install --force-reinstall "git+https://github.com/AzureAD/microsoft-authentication-library-for-python.git[broker]"` +3. (Optional) A proper version of `PyMsalRuntime` has already been installed by the previous command. + But if you want to test a specific version of `PyMsalRuntime`, + you shall manually install that version now. +4. Run this test by `python broker-test.py` and make sure all the tests passed. 
+ """ import msal import getpass @@ -28,6 +38,7 @@ pca = msal.PublicClientApplication( _AZURE_CLI, authority="https://login.microsoftonline.com/organizations", + enable_broker_on_mac=True, enable_broker_on_windows=True) def interactive_and_silent(scopes, auth_scheme, data, expected_token_type): diff --git a/tests/smoke-test.md b/tests/smoke-test.md new file mode 100644 index 00000000..a0d35daf --- /dev/null +++ b/tests/smoke-test.md @@ -0,0 +1,68 @@ +# How to Smoke Test MSAL Python + +The experimental `python -m msal` usage is designed to be an interactive tool, +which can impersonate arbitrary apps and test most of the MSAL Python APIs. +Note that MSAL Python API's behavior is modeled after OIDC behavior in browser, +which is not exactly the same as the broker API's behavior, +even though the two sets of APIs happen to have similar names. + +Tokens acquired during the tests will be cached by MSAL Python. +MSAL Python uses an in-memory token cache by default. +This test tool, however, saves a token cache snapshot on disk upon each exit, +and you may choose to reuse it or start afresh during startup. + +Typical test cases are listed below. + +1. The tool starts with an empty token cache. + In this state, acquire_token_silent() shall always return empty result. + +2. When testing with broker, apps would need to register a certain redirect_uri + for the test cases below to work. + We will also test an app without the required redirect_uri registration, + MSAL Python shall return a meaningful error message on what URIs to register. + +3. Interactive acquire_token_interactive() shall get a token. In particular, + + * The prompt=none option shall succeed when there is a default account, + or error out otherwise. + * The prompt=select_account option shall always prompt with an account picker. + * The prompt=absent option shall prompt an account picker UI + if there are multiple accounts available in browser + and none of them is considered a default account. 
+ In such a case, an optional login_hint=`one_of_the_account@contoso.com` + shall bypass the account picker. + + With a broker, the behavior shall largely match the browser behavior, + unless stated otherwise below. + + * Broker (PyMsalRuntime) on Mac does not support silent signin, + so the prompt=absent will also always prompt. + +4. ROPC (Resource Owner Password Credential, a.k.a. the username password flow). + The acquire_token_by_username_password() is supported by broker on Windows. + As of Oct 2023, it is not yet supported by broker on Mac, + so it will fall back to non-broker behavior. + +5. After step 3 or 4, the acquire_token_silent() shall return a token fast, + because that is the same token returned by step 3 or 4, cached in MSAL Python. + We shall also retest this with the force_refresh=True, + a new token shall be obtained, + typically slower than a token served from MSAL Python's token cache. + +6. POP token. + POP token is supported via broker. + This tool tests the POP token by using a hardcoded Signed Http Request (SHR). + A test is successful if the POP test function returns a token with type as POP. + +7. SSH Cert. + The interactive test and silent test shall behave similarly to + their non ssh-cert counterparts, only the `token_type` would be different. + +8. Test the remove_account() API. It shall always be successful. + This effectively signs out an account from MSAL Python, + we can confirm that by running acquire_token_silent() + and seeing that the account is gone. + + The remove_account() shall also sign out from broker (if broker was enabled), + but it does not sign out the account from browser (even when browser was used). + diff --git a/tests/test_application.py b/tests/test_application.py index 71dc16ea..e565e105 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,11 +1,16 @@ # Note: Since Aug 2019 we move all e2e tests into test_e2e.py, # so this test_application file contains only unit tests without dependency. 
+import json +import logging import sys import time -from msal.application import * -from msal.application import _str2bytes +from unittest.mock import patch, Mock import msal -from msal.application import _merge_claims_challenge_and_capabilities +from msal.application import ( + extract_certs, + ClientApplication, PublicClientApplication, ConfidentialClientApplication, + _str2bytes, _merge_claims_challenge_and_capabilities, +) from tests import unittest from tests.test_token_cache import build_id_token, build_response from tests.http_client import MinimalHttpClient, MinimalResponse @@ -335,6 +340,7 @@ class TestApplicationForRefreshInBehaviors(unittest.TestCase): account = {"home_account_id": "{}.{}".format(uid, utid)} rt = "this is a rt" client_id = "my_app" + soon = 60 # application.py considers tokens within 5 minutes as expired @classmethod def setUpClass(cls): # Initialization at runtime, not interpret-time @@ -409,7 +415,8 @@ def mock_post(url, headers=None, *args, **kwargs): def test_expired_token_and_unavailable_aad_should_return_error(self): # a.k.a. Attempt refresh expired token when AAD unavailable - self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) + self.populate_cache( + access_token="expired at", expires_in=self.soon, refresh_in=-900) error = "something went wrong" def mock_post(url, headers=None, *args, **kwargs): self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) @@ -420,7 +427,8 @@ def mock_post(url, headers=None, *args, **kwargs): def test_expired_token_and_available_aad_should_return_new_token(self): # a.k.a. 
Attempt refresh expired token when AAD available - self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) + self.populate_cache( + access_token="expired at", expires_in=self.soon, refresh_in=-900) new_access_token = "new AT" new_refresh_in = 123 def mock_post(url, headers=None, *args, **kwargs): @@ -722,3 +730,129 @@ def test_client_id_should_be_a_valid_scope(self): self._test_client_id_should_be_a_valid_scope("client_id", []) self._test_client_id_should_be_a_valid_scope("client_id", ["foo"]) + +@patch("sys.platform", new="darwin") # Pretend running on Mac. +@patch("msal.authority.tenant_discovery", new=Mock(return_value={ + "authorization_endpoint": "https://contoso.com/placeholder", + "token_endpoint": "https://contoso.com/placeholder", + })) +class TestMsalBehaviorWithoutPyMsalRuntimeOrBroker(unittest.TestCase): + + @patch("msal.application._init_broker", new=Mock(side_effect=ImportError( + "PyMsalRuntime not installed" + ))) + def test_broker_should_be_disabled_by_default(self): + app = msal.PublicClientApplication( + "client_id", + authority="https://login.microsoftonline.com/common", + ) + self.assertFalse(app._enable_broker) + + @patch("msal.application._init_broker", new=Mock(side_effect=ImportError( + "PyMsalRuntime not installed" + ))) + def test_opt_in_should_error_out_when_pymsalruntime_not_installed(self): + """Because it is actionable to app developer to add dependency declaration""" + with self.assertRaises(ImportError): + app = msal.PublicClientApplication( + "client_id", + authority="https://login.microsoftonline.com/common", + enable_broker_on_mac=True, + ) + + @patch("msal.application._init_broker", new=Mock(side_effect=RuntimeError( + "PyMsalRuntime raises RuntimeError when broker initialization failed" + ))) + def test_should_fallback_when_pymsalruntime_failed_to_initialize_broker(self): + app = msal.PublicClientApplication( + "client_id", + authority="https://login.microsoftonline.com/common", + 
enable_broker_on_mac=True, + ) + self.assertFalse(app._enable_broker) + + +@patch("sys.platform", new="darwin") # Pretend running on Mac. +@patch("msal.authority.tenant_discovery", new=Mock(return_value={ + "authorization_endpoint": "https://contoso.com/placeholder", + "token_endpoint": "https://contoso.com/placeholder", + })) +@patch("msal.application._init_broker", new=Mock()) # Pretend pymsalruntime installed and working +class TestBrokerFallbackWithDifferentAuthorities(unittest.TestCase): + + def test_broker_should_be_disabled_by_default(self): + app = msal.PublicClientApplication( + "client_id", + authority="https://login.microsoftonline.com/common", + ) + self.assertFalse(app._enable_broker) + + def test_broker_should_be_enabled_when_opted_in(self): + app = msal.PublicClientApplication( + "client_id", + authority="https://login.microsoftonline.com/common", + enable_broker_on_mac=True, + ) + self.assertTrue(app._enable_broker) + + def test_should_fallback_to_non_broker_when_using_adfs(self): + app = msal.PublicClientApplication( + "client_id", + authority="https://contoso.com/adfs", + #instance_discovery=False, # Automatically skipped when detected ADFS + enable_broker_on_mac=True, + ) + self.assertFalse(app._enable_broker) + + def test_should_fallback_to_non_broker_when_using_b2c(self): + app = msal.PublicClientApplication( + "client_id", + authority="https://contoso.b2clogin.com/contoso/policy", + #instance_discovery=False, # Automatically skipped when detected B2C + enable_broker_on_mac=True, + ) + self.assertFalse(app._enable_broker) + + def test_should_use_broker_when_disabling_instance_discovery(self): + app = msal.PublicClientApplication( + "client_id", + authority="https://contoso.com/path", + instance_discovery=False, # Need this for a generic authority url + enable_broker_on_mac=True, + ) + # TODO: Shall we bypass broker when opted out of instance discovery? 
+ self.assertTrue(app._enable_broker) # Current implementation enables broker + + def test_should_fallback_to_non_broker_when_using_oidc_authority(self): + app = msal.PublicClientApplication( + "client_id", + oidc_authority="https://contoso.com/path", + enable_broker_on_mac=True, + ) + self.assertFalse(app._enable_broker) + + def test_app_did_not_register_redirect_uri_should_error_out(self): + """Because it is actionable to app developer to add redirect URI""" + app = msal.PublicClientApplication( + "client_id", + authority="https://login.microsoftonline.com/common", + enable_broker_on_mac=True, + ) + self.assertTrue(app._enable_broker) + with patch.object( + # Note: We tried @patch("msal.broker.foo", ...) but it ended up with + # "module msal does not have attribute broker" + app, "_acquire_token_interactive_via_broker", return_value={ + "error": "broker_error", + "error_description": + "(pii). " # pymsalruntime no longer surfaces AADSTS error, + # So MSAL Python can't raise RedirectUriError. 
+ "Status: Response_Status.Status_ApiContractViolation, " + "Error code: 3399614473, Tag 557973642", + }): + result = app.acquire_token_interactive( + ["scope"], + parent_window_handle=app.CONSOLE_WINDOW_HANDLE, + ) + self.assertEqual(result.get("error"), "broker_error") + diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 68ad2af7..a0796547 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -16,6 +16,7 @@ import json import time import unittest +from urllib.parse import urlparse, parse_qs import sys try: from unittest.mock import patch, ANY @@ -171,6 +172,7 @@ def _build_app(cls, client_id, client_credential=None, authority="https://login.microsoftonline.com/common", + oidc_authority=None, scopes=["https://graph.microsoft.com/.default"], # Microsoft Graph http_client=None, azure_region=None, @@ -180,6 +182,7 @@ def _build_app(cls, client_id, client_credential=client_credential, authority=authority, + oidc_authority=oidc_authority, azure_region=azure_region, http_client=http_client or MinimalHttpClient(), ) @@ -193,20 +196,24 @@ def _build_app(cls, return msal.PublicClientApplication( client_id, authority=authority, + oidc_authority=oidc_authority, http_client=http_client or MinimalHttpClient(), enable_broker_on_windows=broker_available, + enable_broker_on_mac=broker_available, ) def _test_username_password(self, authority=None, client_id=None, username=None, password=None, scope=None, + oidc_authority=None, client_secret=None, # Since MSAL 1.11, confidential client has ROPC too azure_region=None, http_client=None, auth_scheme=None, **ignored): - assert authority and client_id and username and password and scope + assert client_id and username and password and scope and ( + authority or oidc_authority) self.app = self._build_app( - client_id, authority=authority, + client_id, authority=authority, oidc_authority=oidc_authority, http_client=http_client, azure_region=azure_region, # Regional endpoint does not support ROPC. 
# Here we just use it to test a regional app won't break ROPC. @@ -227,9 +234,14 @@ def _test_username_password(self, os.getenv("TRAVIS"), # It is set when running on TravisCI or Github Actions "Although it is doable, we still choose to skip device flow to save time") def _test_device_flow( - self, client_id=None, authority=None, scope=None, **ignored): - assert client_id and authority and scope - self.app = self._build_app(client_id, authority=authority) + self, + *, + client_id=None, authority=None, oidc_authority=None, scope=None, + **ignored + ): + assert client_id and scope and (authority or oidc_authority) + self.app = self._build_app( + client_id, authority=authority, oidc_authority=oidc_authority) flow = self.app.initiate_device_flow(scopes=scope) assert "user_code" in flow, "DF does not seem to be provisioned: %s".format( json.dumps(flow, indent=4)) @@ -253,7 +265,8 @@ def _test_device_flow( @unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented") def _test_acquire_token_interactive( - self, client_id=None, authority=None, scope=None, port=None, + self, *, client_id=None, authority=None, scope=None, port=None, + oidc_authority=None, username=None, lab_name=None, username_uri="", # Unnecessary if you provided username and lab_name data=None, # Needed by ssh-cert feature @@ -261,8 +274,9 @@ def _test_acquire_token_interactive( enable_msa_passthrough=None, auth_scheme=None, **ignored): - assert client_id and authority and scope - self.app = self._build_app(client_id, authority=authority) + assert client_id and scope and (authority or oidc_authority) + self.app = self._build_app( + client_id, authority=authority, oidc_authority=oidc_authority) logger.info(_get_hint( # Useful when testing broker which shows no welcome_template username=username, lab_name=lab_name, username_uri=username_uri)) result = self.app.acquire_token_interactive( @@ -680,10 +694,13 @@ def _test_acquire_token_obo(self, config_pca, config_cca, def 
_test_acquire_token_by_client_secret( self, client_id=None, client_secret=None, authority=None, scope=None, + oidc_authority=None, **ignored): - assert client_id and client_secret and authority and scope + assert client_id and client_secret and scope and ( + authority or oidc_authority) self.app = msal.ConfidentialClientApplication( client_id, client_credential=client_secret, authority=authority, + oidc_authority=oidc_authority, http_client=MinimalHttpClient()) result = self.app.acquire_token_for_client(scope) self.assertIsNotNone(result.get("access_token"), "Got %s instead" % result) @@ -1004,14 +1021,18 @@ class CiamTestCase(LabBasedTestCase): @classmethod def setUpClass(cls): super(CiamTestCase, cls).setUpClass() - cls.user = cls.get_lab_user(federationProvider="ciam") + cls.user = cls.get_lab_user( + #federationProvider="ciam", # This line would return ciam2 tenant + federationProvider="ciamcud", signinAudience="AzureAdMyOrg", # ciam6 + ) # FYI: Only single- or multi-tenant CIAM app can have other-than-OIDC # delegated permissions on Microsoft Graph. 
cls.app_config = cls.get_lab_app_object(cls.user["client_id"]) def test_ciam_acquire_token_interactive(self): self._test_acquire_token_interactive( - authority=self.app_config["authority"], + authority=self.app_config.get("authority"), + oidc_authority=self.app_config.get("oidc_authority"), client_id=self.app_config["appId"], scope=self.app_config["scopes"], username=self.user["username"], @@ -1019,13 +1040,18 @@ def test_ciam_acquire_token_interactive(self): ) def test_ciam_acquire_token_for_client(self): + raw_url = self.app_config["clientSecret"] + secret_url = urlparse(raw_url) + if secret_url.query: # Ciam2 era has a query param Secret=name + secret_name = parse_qs(secret_url.query)["Secret"][0] + else: # Ciam6 era has a URL path that ends with the secret name + secret_name = secret_url.path.split("/")[-1] + logger.info('Detected secret name "%s" from "%s"', secret_name, raw_url) self._test_acquire_token_by_client_secret( client_id=self.app_config["appId"], - client_secret=self.get_lab_user_secret( - self.app_config["clientSecret"].split("=")[-1]), - authority=self.app_config["authority"], - #scope=["{}/.default".format(self.app_config["appId"])], # AADSTS500207: The account type can't be used for the resource you're trying to access. - #scope=["api://{}/.default".format(self.app_config["appId"])], # AADSTS500011: The resource principal named api://ced781e7-bdb0-4c99-855c-d3bacddea88a was not found in the tenant named MSIDLABCIAM2. This can happen if the application has not been installed by the administrator of the tenant or consented to by any user in the tenant. You might have sent your authentication request to the wrong tenant. 
+ client_secret=self.get_lab_user_secret(secret_name), + authority=self.app_config.get("authority"), + oidc_authority=self.app_config.get("oidc_authority"), scope=self.app_config["scopes"], # It shall ends with "/.default" ) @@ -1038,21 +1064,35 @@ def test_ciam_acquire_token_by_ropc(self): # and enabling "Allow public client flows". # Otherwise it would hit AADSTS7000218. self._test_username_password( - authority=self.app_config["authority"], + authority=self.app_config.get("authority"), + oidc_authority=self.app_config.get("oidc_authority"), client_id=self.app_config["appId"], username=self.user["username"], password=self.get_lab_user_secret(self.user["lab_name"]), scope=self.app_config["scopes"], ) + @unittest.skip("""As of Aug 2024, in both ciam2 and ciam6, sign-in fails with +AADSTS500208: The domain is not a valid login domain for the account type.""") def test_ciam_device_flow(self): self._test_device_flow( - authority=self.app_config["authority"], + authority=self.app_config.get("authority"), + oidc_authority=self.app_config.get("oidc_authority"), client_id=self.app_config["appId"], scope=self.app_config["scopes"], ) +class CiamCudTestCase(CiamTestCase): + @classmethod + def setUpClass(cls): + super(CiamCudTestCase, cls).setUpClass() + cls.app_config["authority"] = None + cls.app_config["oidc_authority"] = ( + # Derived from https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.63.0/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/CiamIntegrationTests.cs#L156 + "https://login.msidlabsciam.com/fe362aec-5d43-45d1-b730-9755e60dc3b9/v2.0") + + class WorldWideRegionalEndpointTestCase(LabBasedTestCase): region = "westus" timeout = 2 # Short timeout makes this test case responsive on non-VM @@ -1090,11 +1130,23 @@ def _test_acquire_token_for_client(self, configured_region, expected_region): def test_acquire_token_for_client_should_hit_global_endpoint_by_default(self): self._test_acquire_token_for_client(None, None) - def 
test_acquire_token_for_client_should_ignore_env_var_by_default(self): + def test_acquire_token_for_client_should_ignore_env_var_region_name_by_default(self): os.environ["REGION_NAME"] = "eastus" self._test_acquire_token_for_client(None, None) del os.environ["REGION_NAME"] + @patch.dict(os.environ, {"MSAL_FORCE_REGION": "eastus"}) + def test_acquire_token_for_client_should_use_env_var_msal_force_region_by_default(self): + self._test_acquire_token_for_client(None, "eastus") + + @patch.dict(os.environ, {"MSAL_FORCE_REGION": "eastus"}) + def test_acquire_token_for_client_should_prefer_the_explicit_region(self): + self._test_acquire_token_for_client("westus", "westus") + + @patch.dict(os.environ, {"MSAL_FORCE_REGION": "eastus"}) + def test_acquire_token_for_client_should_allow_opt_out_env_var_msal_force_region(self): + self._test_acquire_token_for_client(False, None) + def test_acquire_token_for_client_should_use_a_specified_region(self): self._test_acquire_token_for_client("westus", "westus") diff --git a/tests/test_mi.py b/tests/test_mi.py index d3a83a0c..c5a99ae3 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -61,6 +61,14 @@ def setUp(self): http_client=requests.Session(), ) + def test_error_out_on_invalid_input(self): + with self.assertRaises(ManagedIdentityError): + ManagedIdentityClient({"foo": "bar"}, http_client=requests.Session()) + with self.assertRaises(ManagedIdentityError): + ManagedIdentityClient( + {"ManagedIdentityIdType": "undefined", "Id": "foo"}, + http_client=requests.Session()) + def assertCacheStatus(self, app): cache = app._token_cache._cache self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT") @@ -82,20 +90,17 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"): self.assertTrue( is_subdict_of(expected_result, result), # We will test refresh_on later "Should obtain a token response") + self.assertEqual("identity_provider", result["token_source"], "Should come from identity provider") self.assertEqual(expires_in, result["expires_in"],
"Should have expected expires_in") if expires_in >= 7200: expected_refresh_on = int(time.time() + expires_in / 2) self.assertTrue( expected_refresh_on - 1 <= result["refresh_on"] <= expected_refresh_on + 1, "Should have a refresh_on time around the middle of the token's life") - self.assertEqual( - result["access_token"], - app.acquire_token_for_client(resource=resource).get("access_token"), - "Should hit the same token from cache") - - self.assertCacheStatus(app) result = app.acquire_token_for_client(resource=resource) + self.assertCacheStatus(app) + self.assertEqual("cache", result["token_source"], "Should hit cache") self.assertEqual( call_count, mocked_http.call_count, "No new call to the mocked http should be made for a cache hit") @@ -110,6 +115,9 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"): expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on, "Should have a refresh_on time around the middle of the token's life") + result = app.acquire_token_for_client(resource=resource, claims_challenge="foo") + self.assertEqual("identity_provider", result["token_source"], "Should miss cache") + class VmTestCase(ClientTestCase): @@ -131,6 +139,22 @@ def test_vm_error_should_be_returned_as_is(self): json.loads(raw_error), self.app.acquire_token_for_client(resource="R")) self.assertEqual({}, self.app._token_cache._cache) + def test_vm_resource_id_parameter_should_be_msi_res_id(self): + app = ManagedIdentityClient( + {"ManagedIdentityIdType": "ResourceId", "Id": "1234"}, + http_client=requests.Session(), + ) + with patch.object(app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_in": 3600, "resource": "R"}', + )) as mocked_method: + app.acquire_token_for_client(resource="R") + mocked_method.assert_called_with( + 'http://169.254.169.254/metadata/identity/oauth2/token', + params={'api-version': '2018-02-01', 'resource': 'R', 'msi_res_id': '1234'}, + headers={'Metadata': 
'true'}, + ) + @patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"}) class AppServiceTestCase(ClientTestCase): @@ -156,6 +180,22 @@ def test_app_service_error_should_be_normalized(self): }, self.app.acquire_token_for_client(resource="R")) self.assertEqual({}, self.app._token_cache._cache) + def test_app_service_resource_id_parameter_should_be_mi_res_id(self): + app = ManagedIdentityClient( + {"ManagedIdentityIdType": "ResourceId", "Id": "1234"}, + http_client=requests.Session(), + ) + with patch.object(app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_on": 12345, "resource": "R"}', + )) as mocked_method: + app.acquire_token_for_client(resource="R") + mocked_method.assert_called_with( + 'http://localhost', + params={'api-version': '2019-08-01', 'resource': 'R', 'mi_res_id': '1234'}, + headers={'X-IDENTITY-HEADER': 'foo', 'Metadata': 'true'}, + ) + @patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"}) class MachineLearningTestCase(ClientTestCase): @@ -241,6 +281,9 @@ class ArcTestCase(ClientTestCase): "WWW-Authenticate": "Basic realm=/tmp/foo", }) + def test_error_out_on_invalid_input(self, mocked_stat): + return super(ArcTestCase, self).test_error_out_on_invalid_input() + def test_happy_path(self, mocked_stat): expires_in = 1234 with patch.object(self.app._http_client, "get", side_effect=[ @@ -249,7 +292,8 @@ def test_happy_path(self, mocked_stat): status_code=200, text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in, ), - ]) as mocked_method: + ] * 2, # Duplicate a pair of mocks for _test_happy_path()'s CAE check + ) as mocked_method: try: self._test_happy_path(self.app, mocked_method, expires_in) mocked_stat.assert_called_with(os.path.join( diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 4e301fa3..494d6daf 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -3,7 
+3,7 @@ import json import time -from msal.token_cache import * +from msal.token_cache import TokenCache, SerializableTokenCache from tests import unittest @@ -51,11 +51,14 @@ class TokenCacheTestCase(unittest.TestCase): def setUp(self): self.cache = TokenCache() + self.at_key_maker = self.cache.key_makers[ + TokenCache.CredentialType.ACCESS_TOKEN] def testAddByAad(self): client_id = "my_client_id" id_token = build_id_token( oid="object1234", preferred_username="John Doe", aud=client_id) + now = 1000 self.cache.add({ "client_id": client_id, "scope": ["s2", "s1", "s3"], # Not in particular order @@ -64,7 +67,7 @@ def testAddByAad(self): uid="uid", utid="utid", # client_info expires_in=3600, access_token="an access token", id_token=id_token, refresh_token="a refresh token"), - }, now=1000) + }, now=now) access_token_entry = { 'cached_at': "1000", 'client_id': 'my_client_id', @@ -78,14 +81,11 @@ def testAddByAad(self): 'target': 's1 s2 s3', # Sorted 'token_type': 'some type', } - self.assertEqual( - access_token_entry, - self.cache._cache["AccessToken"].get( - 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3') - ) + self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get( + self.at_key_maker(**access_token_entry))) self.assertIn( access_token_entry, - self.cache.find(self.cache.CredentialType.ACCESS_TOKEN), + self.cache.find(self.cache.CredentialType.ACCESS_TOKEN, now=now), "find(..., query=None) should not crash, even though MSAL does not use it") self.assertEqual( { @@ -144,8 +144,7 @@ def testAddByAdfs(self): expires_in=3600, access_token="an access token", id_token=id_token, refresh_token="a refresh token"), }, now=1000) - self.assertEqual( - { + access_token_entry = { 'cached_at': "1000", 'client_id': 'my_client_id', 'credential_type': 'AccessToken', @@ -157,10 +156,9 @@ def testAddByAdfs(self): 'secret': 'an access token', 'target': 's1 s2 s3', # Sorted 'token_type': 'some type', - }, - self.cache._cache["AccessToken"].get( 
- 'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3') - ) + } + self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get( + self.at_key_maker(**access_token_entry))) self.assertEqual( { 'client_id': 'my_client_id', @@ -206,37 +204,67 @@ def testAddByAdfs(self): "appmetadata-fs.msidlab8.com-my_client_id") ) - def test_key_id_is_also_recorded(self): - my_key_id = "some_key_id_123" + def assertFoundAccessToken(self, *, scopes, query, data=None, now=None): + cached_at = None + for cached_at in self.cache.search( + TokenCache.CredentialType.ACCESS_TOKEN, + target=scopes, query=query, now=now, + ): + for k, v in (data or {}).items(): # The extra data, if any + self.assertEqual(cached_at.get(k), v, f"AT should contain {k}={v}") + self.assertTrue(cached_at, "AT should be cached and searchable") + return cached_at + + def _test_data_should_be_saved_and_searchable_in_access_token(self, data): + scopes = ["s2", "s1", "s3"] # Not in particular order + now = 1000 self.cache.add({ - "data": {"key_id": my_key_id}, + "data": data, "client_id": "my_client_id", - "scope": ["s2", "s1", "s3"], # Not in particular order + "scope": scopes, "token_endpoint": "https://login.example.com/contoso/v2/token", "response": build_response( uid="uid", utid="utid", # client_info expires_in=3600, access_token="an access token", refresh_token="a refresh token"), - }, now=1000) - cached_key_id = self.cache._cache["AccessToken"].get( - 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3', - {}).get("key_id") - self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key") + }, now=now) + self.assertFoundAccessToken(scopes=scopes, data=data, now=now, query=dict( + data, # Also use the extra data as a query criteria + client_id="my_client_id", + environment="login.example.com", + realm="contoso", + home_account_id="uid.utid", + )) + + def test_extra_data_should_also_be_recorded_and_searchable_in_access_token(self): + 
self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"}) + + def test_access_tokens_with_different_key_id(self): + self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"}) + self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "2"}) + self.assertEqual( + len(self.cache._cache["AccessToken"]), + 1, """Historically, tokens are not keyed by key_id, +so a new token overwrites the old one, and we would end up with 1 token in cache""") def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep. + scopes = ["s2", "s1", "s3"] # Not in particular order self.cache.add({ "client_id": "my_client_id", - "scope": ["s2", "s1", "s3"], # Not in particular order + "scope": scopes, "token_endpoint": "https://login.example.com/contoso/v2/token", "response": build_response( uid="uid", utid="utid", # client_info expires_in=3600, refresh_in=1800, access_token="an access token", ), #refresh_token="a refresh token"), }, now=1000) - refresh_on = self.cache._cache["AccessToken"].get( - 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3', - {}).get("refresh_on") - self.assertEqual("2800", refresh_on, "Should save refresh_on") + at = self.assertFoundAccessToken(scopes=scopes, query=dict( + client_id="my_client_id", + environment="login.example.com", + realm="contoso", + home_account_id="uid.utid", + )) + self.assertEqual("2800", at.get("refresh_on"), "Should save refresh_on") def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self): sample = { @@ -258,7 +286,7 @@ def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self): ) -class SerializableTokenCacheTestCase(TokenCacheTestCase): +class SerializableTokenCacheTestCase(unittest.TestCase): # Run all inherited test methods, and have extra check in tearDown() def setUp(self):