Skip to content

Commit

Permalink
Acquire SSH cert from Cloud Shell IMDS
Browse files Browse the repository at this point in the history
Cloud Shell Detection

PoC: Silent flow utilizes Cloud Shell IMDS

Introduce get_accounts(username=msal.CURRENT_USER)

A reasonable-effort to convert scope to resource

Replace get_accounts(username=msal.CURRENT_USER) by acquire_token_interactive(..., prompt="none")

Detect unsupported Portal so that AzCLI could fallback
  • Loading branch information
rayluo committed May 19, 2022
1 parent c7e81ba commit 292e28b
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 6 deletions.
33 changes: 30 additions & 3 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@
import msal.telemetry
from .region import _detect_region
from .throttled_http_client import ThrottledHttpClient
from .cloudshell import _is_running_in_cloud_shell


# The __init__.py will import this. Not the other way around.
__version__ = "1.17.0" # When releasing, also check and bump our dependencies's versions if needed

logger = logging.getLogger(__name__)

_AUTHORITY_TYPE_CLOUDSHELL = "CLOUDSHELL"

def extract_certs(public_cert_content):
# Parses raw public certificate file contents and returns a list of strings
Expand Down Expand Up @@ -986,6 +987,10 @@ def get_accounts(self, username=None):
return accounts

def _find_msal_accounts(self, environment):
interested_authority_types = [
TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS]
if _is_running_in_cloud_shell():
interested_authority_types.append(_AUTHORITY_TYPE_CLOUDSHELL)
grouped_accounts = {
a.get("home_account_id"): # Grouped by home tenant's id
{ # These are minimal amount of non-tenant-specific account info
Expand All @@ -1001,8 +1006,7 @@ def _find_msal_accounts(self, environment):
for a in self.token_cache.find(
TokenCache.CredentialType.ACCOUNT,
query={"environment": environment})
if a["authority_type"] in (
TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS)
if a["authority_type"] in interested_authority_types
}
return list(grouped_accounts.values())

Expand Down Expand Up @@ -1062,6 +1066,21 @@ def _forget_me(self, home_account):
TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account):
self.token_cache.remove_account(a)

def _acquire_token_by_cloud_shell(self, scopes, data=None):
from .cloudshell import _obtain_token
response = _obtain_token(
self.http_client, scopes, client_id=self.client_id, data=data)
if "error" not in response:
self.token_cache.add(dict(
client_id=self.client_id,
scope=response["scope"].split() if "scope" in response else scopes,
token_endpoint=self.authority.token_endpoint,
response=response.copy(),
data=data or {},
authority_type=_AUTHORITY_TYPE_CLOUDSHELL,
))
return response

def acquire_token_silent(
self,
scopes, # type: List[str]
Expand Down Expand Up @@ -1195,6 +1214,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
authority, # This can be different than self.authority
force_refresh=False, # type: Optional[boolean]
claims_challenge=None,
correlation_id=None,
**kwargs):
access_token_from_cache = None
if not (force_refresh or claims_challenge): # Bypass AT when desired or using claims
Expand Down Expand Up @@ -1233,9 +1253,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge
assert refresh_reason, "It should have been established at this point"
try:
if account and account.get("authority_type") == _AUTHORITY_TYPE_CLOUDSHELL:
return self._acquire_token_by_cloud_shell(
scopes, data=kwargs.get("data"))
result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, self._decorate_scope(scopes), account,
refresh_reason=refresh_reason, claims_challenge=claims_challenge,
correlation_id=correlation_id,
**kwargs))
if (result and "error" not in result) or (not access_token_from_cache):
return result
Expand Down Expand Up @@ -1574,6 +1598,9 @@ def acquire_token_interactive(
- A dict containing an "error" key, when token refresh failed.
"""
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
if _is_running_in_cloud_shell() and prompt == "none":
return self._acquire_token_by_cloud_shell(
scopes, data=kwargs.pop("data", {}))
claims = _merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)
telemetry_context = self._build_telemetry_context(
Expand Down
122 changes: 122 additions & 0 deletions msal/cloudshell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.

"""This module wraps Cloud Shell's IMDS-like interface inside an OAuth2-like helper"""
import base64
import json
import logging
import os
import time
try: # Python 2
from urlparse import urlparse
except: # Python 3
from urllib.parse import urlparse
from .oauth2cli.oidc import decode_part


logger = logging.getLogger(__name__)


def _is_running_in_cloud_shell():
return os.environ.get("AZUREPS_HOST_ENVIRONMENT", "").startswith("cloud-shell")


def _scope_to_resource(scope): # This is an experimental reasonable-effort approach
cloud_shell_supported_audiences = [
"https://analysis.windows.net/powerbi/api", # Came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json
"https://pas.windows.net/CheckMyAccess/Linux/.default", # Cloud Shell accepts it as-is
]
for a in cloud_shell_supported_audiences:
if scope.startswith(a):
return a
u = urlparse(scope)
if u.scheme:
return "{}://{}".format(u.scheme, u.netloc)
return scope # There is no much else we can do here


def _obtain_token(http_client, scopes, client_id=None, data=None):
resp = http_client.post(
"http://localhost:50342/oauth2/token",
data=dict(
data or {},
resource=" ".join(map(_scope_to_resource, scopes))),
headers={"Metadata": "true"},
)
if resp.status_code >= 300:
logger.debug("Cloud Shell IMDS error: %s", resp.text)
cs_error = json.loads(resp.text).get("error", {})
return {k: v for k, v in {
"error": cs_error.get("code"),
"error_description": cs_error.get("message"),
}.items() if v}
imds_payload = json.loads(resp.text)
BEARER = "Bearer"
oauth2_response = {
"access_token": imds_payload["access_token"],
"expires_in": int(imds_payload["expires_in"]),
"token_type": imds_payload.get("token_type", BEARER),
}
expected_token_type = (data or {}).get("token_type", BEARER)
if oauth2_response["token_type"] != expected_token_type:
return { # Generate a normal error (rather than an intrusive exception)
"error": "broker_error",
"error_description": "token_type {} is not supported by this version of Azure Portal".format(
expected_token_type),
}
parts = imds_payload["access_token"].split(".")

# The following default values are useful in SSH Cert scenario
client_info = { # Default value, in case the real value will be unavailable
"uid": "user",
"utid": "cloudshell",
}
now = time.time()
preferred_username = "currentuser@cloudshell"
oauth2_response["id_token_claims"] = { # First 5 claims are required per OIDC
"iss": "cloudshell",
"sub": "user",
"aud": client_id,
"exp": now + 3600,
"iat": now,
"preferred_username": preferred_username, # Useful as MSAL account's username
}

if len(parts) == 3: # Probably a JWT. Use it to derive client_info and id token.
try:
# Data defined in https://docs.microsoft.com/en-us/azure/active-directory/develop/access-tokens#payload-claims
jwt_payload = json.loads(decode_part(parts[1]))
client_info = {
# Mimic a real home_account_id,
# so that this pseudo account and a real account would interop.
"uid": jwt_payload.get("oid", "user"),
"utid": jwt_payload.get("tid", "cloudshell"),
}
oauth2_response["id_token_claims"] = {
"iss": jwt_payload["iss"],
"sub": jwt_payload["sub"], # Could use oid instead
"aud": client_id,
"exp": jwt_payload["exp"],
"iat": jwt_payload["iat"],
"preferred_username": jwt_payload.get("preferred_username") # V2
or jwt_payload.get("unique_name") # V1
or preferred_username,
}
except ValueError:
logger.debug("Unable to decode jwt payload: %s", parts[1])
oauth2_response["client_info"] = base64.b64encode(
# Mimic a client_info, so that MSAL would create an account
json.dumps(client_info).encode("utf-8")).decode("utf-8")
oauth2_response["id_token_claims"]["tid"] = client_info["utid"] # TBD

## Note: Decided to not surface resource back as scope,
## because they would cause the downstream OAuth2 code path to
## cache the token with a different scope and won't hit them later.
#if imds_payload.get("resource"):
# oauth2_response["scope"] = imds_payload["resource"]
if imds_payload.get("refresh_token"):
oauth2_response["refresh_token"] = imds_payload["refresh_token"]
return oauth2_response

9 changes: 6 additions & 3 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def wipe(dictionary, sensitive_fields): # Masks sensitive info
return self.__add(event, now=now)
finally:
wipe(event.get("response", {}), ( # These claims were useful during __add()
"id_token_claims", # Provided by broker
"access_token", "refresh_token", "id_token", "username"))
wipe(event, ["username"]) # Needed for federated ROPC
logger.debug("event=%s", json.dumps(
Expand Down Expand Up @@ -150,7 +151,8 @@ def __add(self, event, now=None):
id_token = response.get("id_token")
id_token_claims = (
decode_id_token(id_token, client_id=event["client_id"])
if id_token else {})
if id_token
else response.get("id_token_claims", {})) # Broker would provide id_token_claims
client_info, home_account_id = self.__parse_account(response, id_token_claims)

target = ' '.join(event.get("scope") or []) # Per schema, we don't sort it
Expand Down Expand Up @@ -195,9 +197,10 @@ def __add(self, event, now=None):
or data.get("username") # Falls back to ROPC username
or event.get("username") # Falls back to Federated ROPC username
or "", # The schema does not like null
"authority_type":
"authority_type": event.get(
"authority_type", # Honor caller's choice of authority_type
self.AuthorityType.ADFS if realm == "adfs"
else self.AuthorityType.MSSTS,
else self.AuthorityType.MSSTS),
# "client_info": response.get("client_info"), # Optional
}
self.modify(self.CredentialType.ACCOUNT, account, account)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,14 @@ def _test_acquire_token_interactive(
self, client_id=None, authority=None, scope=None, port=None,
username_uri="", # But you would want to provide one
data=None, # Needed by ssh-cert feature
prompt=None,
**ignored):
assert client_id and authority and scope
self.app = msal.PublicClientApplication(
client_id, authority=authority, http_client=MinimalHttpClient())
result = self.app.acquire_token_interactive(
scope,
prompt=prompt,
timeout=120,
port=port,
welcome_template= # This is an undocumented feature for testing
Expand Down Expand Up @@ -237,6 +239,7 @@ def test_ssh_cert_for_user(self):
scope=self.SCOPE,
data=self.DATA1,
username_uri="https://msidlab.com/api/user?usertype=cloud",
prompt="none" if msal.application._is_running_in_cloud_shell() else None,
) # It already tests reading AT from cache, and using RT to refresh
# acquire_token_silent() would work because we pass in the same key
self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format(
Expand All @@ -254,6 +257,20 @@ def test_ssh_cert_for_user(self):
self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token'])


@unittest.skipUnless(
msal.application._is_running_in_cloud_shell(),
"Manually run this test case from inside Cloud Shell")
class CloudShellTestCase(E2eTestCase):
app = msal.PublicClientApplication("client_id")
scope_that_requires_no_managed_device = "https://management.core.windows.net/" # Scopes came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json&version=GBmaster&_a=contents
def test_access_token_should_be_obtained_for_a_supported_scope(self):
result = self.app.acquire_token_interactive(
[self.scope_that_requires_no_managed_device], prompt="none")
self.assertEqual(
"Bearer", result.get("token_type"), "Unexpected result: %s" % result)
self.assertIsNotNone(result.get("access_token"))


THIS_FOLDER = os.path.dirname(__file__)
CONFIG = os.path.join(THIS_FOLDER, "config.json")
@unittest.skipUnless(os.path.exists(CONFIG), "Optional %s not found" % CONFIG)
Expand Down

0 comments on commit 292e28b

Please sign in to comment.