Skip to content

[SSH] SSH support including ATs bound to keys #102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def acquire_token_by_authorization_code(
# one scope. But, MSAL decorates your scope anyway, so they are never
# really empty.
assert isinstance(scopes, list), "Invalid parameter type"
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_by_authorization_code(
code, redirect_uri=redirect_uri,
data=dict(
Expand Down Expand Up @@ -396,6 +397,7 @@ def acquire_token_silent(
- None when cache lookup does not yield anything.
"""
assert isinstance(scopes, list), "Invalid parameter type"
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
if authority:
warnings.warn("We haven't decided how/if this method will accept authority parameter")
# the_authority = Authority(
Expand Down Expand Up @@ -424,15 +426,19 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
force_refresh=False, # type: Optional[boolean]
**kwargs):
if not force_refresh:
matches = self.token_cache.find(
self.token_cache.CredentialType.ACCESS_TOKEN,
target=scopes,
query={
query={
"client_id": self.client_id,
"environment": authority.instance,
"realm": authority.tenant,
"home_account_id": (account or {}).get("home_account_id"),
})
}
key_id = kwargs.get("data", {}).get("key_id")
if key_id: # Some token types (SSH-certs, POP) are bound to a key
query["key_id"] = key_id
matches = self.token_cache.find(
self.token_cache.CredentialType.ACCESS_TOKEN,
target=scopes,
query=query)
now = time.time()
for entry in matches:
expires_in = int(entry["expires_on"]) - now
Expand Down Expand Up @@ -513,6 +519,20 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
if break_condition(response):
break

def _validate_ssh_cert_input_data(self, data):
if data.get("token_type") == "ssh-cert":
if not data.get("req_cnf"):
raise ValueError(
"When requesting an SSH certificate, "
"you must include a string parameter named 'req_cnf' "
"containing the public key in JWK format "
"(https://tools.ietf.org/html/rfc7517).")
if not data.get("key_id"):
raise ValueError(
"When requesting an SSH certificate, "
"you must include a string parameter named 'key_id' "
"which identifies the key in the 'req_cnf' argument.")


class PublicClientApplication(ClientApplication): # browser app or mobile app

Expand Down
3 changes: 3 additions & 0 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __add(self, event, now=None):
if "token_endpoint" in event:
_, environment, realm = canonicalize(event["token_endpoint"])
response = event.get("response", {})
data = event.get("data", {})
access_token = response.get("access_token")
refresh_token = response.get("refresh_token")
id_token = response.get("id_token")
Expand Down Expand Up @@ -165,6 +166,8 @@ def __add(self, event, now=None):
"expires_on": str(now + expires_in), # Same here
"extended_expires_on": str(now + ext_expires_in) # Same here
}
if data.get("key_id"): # It happens in SSH-cert or POP scenario
at["key_id"] = data.get("key_id")
self.modify(self.CredentialType.ACCESS_TOKEN, at, at)

if client_info:
Expand Down
52 changes: 48 additions & 4 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,23 @@ def test_username_password(self):
self.skipUnlessWithConfig(["client_id", "username", "password", "scope"])
self._test_username_password(**self.config)

def test_auth_code(self):
self.skipUnlessWithConfig(["client_id", "scope"])
def _get_app_and_auth_code(self):
from msal.oauth2cli.authcode import obtain_auth_code
self.app = msal.ClientApplication(
app = msal.ClientApplication(
self.config["client_id"],
client_credential=self.config.get("client_secret"),
authority=self.config.get("authority"))
port = self.config.get("listen_port", 44331)
redirect_uri = "http://localhost:%s" % port
auth_request_uri = self.app.get_authorization_request_url(
auth_request_uri = app.get_authorization_request_url(
self.config["scope"], redirect_uri=redirect_uri)
ac = obtain_auth_code(port, auth_uri=auth_request_uri)
self.assertNotEqual(ac, None)
return (app, ac, redirect_uri)

def test_auth_code(self):
self.skipUnlessWithConfig(["client_id", "scope"])
(self.app, ac, redirect_uri) = self._get_app_and_auth_code()

result = self.app.acquire_token_by_authorization_code(
ac, self.config["scope"], redirect_uri=redirect_uri)
Expand All @@ -120,6 +124,46 @@ def test_auth_code(self):
error_description=result.get("error_description")))
self.assertCacheWorksForUser(result, self.config["scope"], username=None)


def test_ssh_cert(self):
self.skipUnlessWithConfig(["client_id", "scope"])

JWK1 = """{"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}"""
JWK2 = """{"kty":"RSA", "n":"72u07mew8rw-ssw3tUs9clKstGO2lvD7ZNxJU7OPNKz5PGYx3gjkhUmtNah4I4FP0DuF1ogb_qSS5eD86w10Wb1ftjWcoY8zjNO9V3ph-Q2tMQWdDW5kLdeU3-EDzc0HQeou9E0udqmfQoPbuXFQcOkdcbh3eeYejs8sWn3TQprXRwGh_TRYi-CAurXXLxQ8rp-pltUVRIr1B63fXmXhMeCAGwCPEFX9FRRs-YHUszUJl9F9-E0nmdOitiAkKfCC9LhwB9_xKtjmHUM9VaEC9jWOcdvXZutwEoW2XPMOg0Ky-s197F9rfpgHle2gBrXsbvVMvS0D-wXg6vsq6BAHzQ", "e":"AQAB"}"""
data1 = {"token_type": "ssh-cert", "key_id": "key1", "req_cnf": JWK1}
ssh_test_slice = {
"dc": "prod-wst-test1",
"slice": "test",
"sshcrt": "true",
}

(self.app, ac, redirect_uri) = self._get_app_and_auth_code()

result = self.app.acquire_token_by_authorization_code(
ac, self.config["scope"], redirect_uri=redirect_uri, data=data1,
params=ssh_test_slice)
self.assertEqual("ssh-cert", result["token_type"])
logger.debug("%s.cache = %s",
self.id(), json.dumps(self.app.token_cache._cache, indent=4))

# acquire_token_silent() needs to be passed the same key to work
account = self.app.get_accounts()[0]
result_from_cache = self.app.acquire_token_silent(
self.config["scope"], account=account, data=data1)
self.assertIsNotNone(result_from_cache)
self.assertEqual(
result['access_token'], result_from_cache['access_token'],
"We should get the cached SSH-cert")

# refresh_token grant can fetch an ssh-cert bound to a different key
refreshed_ssh_cert = self.app.acquire_token_silent(
self.config["scope"], account=account, params=ssh_test_slice,
data={"token_type": "ssh-cert", "key_id": "key2", "req_cnf": JWK2})
self.assertIsNotNone(refreshed_ssh_cert)
self.assertEqual(refreshed_ssh_cert["token_type"], "ssh-cert")
self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token'])


def test_client_secret(self):
self.skipUnlessWithConfig(["client_id", "client_secret"])
self.app = msal.ConfidentialClientApplication(
Expand Down
17 changes: 17 additions & 0 deletions tests/test_token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,23 @@ def testAddByAdfs(self):
"appmetadata-fs.msidlab8.com-my_client_id")
)

def test_key_id_is_also_recorded(self):
my_key_id = "some_key_id_123"
self.cache.add({
"data": {"key_id": my_key_id},
"client_id": "my_client_id",
"scope": ["s2", "s1", "s3"], # Not in particular order
"token_endpoint": "https://login.example.com/contoso/v2/token",
"response": self.build_response(
uid="uid", utid="utid", # client_info
expires_in=3600, access_token="an access token",
refresh_token="a refresh token"),
}, now=1000)
cached_key_id = self.cache._cache["AccessToken"].get(
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3',
{}).get("key_id")
self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key")


class SerializableTokenCacheTestCase(TokenCacheTestCase):
# Run all inherited test methods, and have extra check in tearDown()
Expand Down