Skip to content

Commit 4550793

Browse files
bgavrilMSrayluo
authored andcommitted
[SSH] SSH support including ATs bound to keys (#102)
* [SSH] SSH support including ATs bound to keys * Fix test * Do not use casefold() * Move _validate_ssh_cert_input_data() to outer layer This at least avoids the performance penalty in those implicit loops inside acquire_token_silent(). * Code style "Don't use spaces around the = sign when used to indicate a keyword argument" - quoted from https://www.python.org/dev/peps/pep-0008/#other-recommendations * General tidy up * Move AT key_id test into its own unit test * Remove an extra space * Fix typo introduced in online editing via browser :-(
1 parent 97d29b1 commit 4550793

File tree

4 files changed

+93
-9
lines changed

4 files changed

+93
-9
lines changed

msal/application.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ def acquire_token_by_authorization_code(
269269
# one scope. But, MSAL decorates your scope anyway, so they are never
270270
# really empty.
271271
assert isinstance(scopes, list), "Invalid parameter type"
272+
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
272273
return self.client.obtain_token_by_authorization_code(
273274
code, redirect_uri=redirect_uri,
274275
data=dict(
@@ -396,6 +397,7 @@ def acquire_token_silent(
396397
- None when cache lookup does not yield anything.
397398
"""
398399
assert isinstance(scopes, list), "Invalid parameter type"
400+
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
399401
if authority:
400402
warnings.warn("We haven't decided how/if this method will accept authority parameter")
401403
# the_authority = Authority(
@@ -424,15 +426,19 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
424426
force_refresh=False, # type: Optional[boolean]
425427
**kwargs):
426428
if not force_refresh:
427-
matches = self.token_cache.find(
428-
self.token_cache.CredentialType.ACCESS_TOKEN,
429-
target=scopes,
430-
query={
429+
query={
431430
"client_id": self.client_id,
432431
"environment": authority.instance,
433432
"realm": authority.tenant,
434433
"home_account_id": (account or {}).get("home_account_id"),
435-
})
434+
}
435+
key_id = kwargs.get("data", {}).get("key_id")
436+
if key_id: # Some token types (SSH-certs, POP) are bound to a key
437+
query["key_id"] = key_id
438+
matches = self.token_cache.find(
439+
self.token_cache.CredentialType.ACCESS_TOKEN,
440+
target=scopes,
441+
query=query)
436442
now = time.time()
437443
for entry in matches:
438444
expires_in = int(entry["expires_on"]) - now
@@ -513,6 +519,20 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
513519
if break_condition(response):
514520
break
515521

522+
def _validate_ssh_cert_input_data(self, data):
523+
if data.get("token_type") == "ssh-cert":
524+
if not data.get("req_cnf"):
525+
raise ValueError(
526+
"When requesting an SSH certificate, "
527+
"you must include a string parameter named 'req_cnf' "
528+
"containing the public key in JWK format "
529+
"(https://tools.ietf.org/html/rfc7517).")
530+
if not data.get("key_id"):
531+
raise ValueError(
532+
"When requesting an SSH certificate, "
533+
"you must include a string parameter named 'key_id' "
534+
"which identifies the key in the 'req_cnf' argument.")
535+
516536

517537
class PublicClientApplication(ClientApplication): # browser app or mobile app
518538

msal/token_cache.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def __add(self, event, now=None):
127127
if "token_endpoint" in event:
128128
_, environment, realm = canonicalize(event["token_endpoint"])
129129
response = event.get("response", {})
130+
data = event.get("data", {})
130131
access_token = response.get("access_token")
131132
refresh_token = response.get("refresh_token")
132133
id_token = response.get("id_token")
@@ -165,6 +166,8 @@ def __add(self, event, now=None):
165166
"expires_on": str(now + expires_in), # Same here
166167
"extended_expires_on": str(now + ext_expires_in) # Same here
167168
}
169+
if data.get("key_id"): # It happens in SSH-cert or POP scenario
170+
at["key_id"] = data.get("key_id")
168171
self.modify(self.CredentialType.ACCESS_TOKEN, at, at)
169172

170173
if client_info:

tests/test_e2e.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,19 +94,23 @@ def test_username_password(self):
9494
self.skipUnlessWithConfig(["client_id", "username", "password", "scope"])
9595
self._test_username_password(**self.config)
9696

97-
def test_auth_code(self):
98-
self.skipUnlessWithConfig(["client_id", "scope"])
97+
def _get_app_and_auth_code(self):
9998
from msal.oauth2cli.authcode import obtain_auth_code
100-
self.app = msal.ClientApplication(
99+
app = msal.ClientApplication(
101100
self.config["client_id"],
102101
client_credential=self.config.get("client_secret"),
103102
authority=self.config.get("authority"))
104103
port = self.config.get("listen_port", 44331)
105104
redirect_uri = "http://localhost:%s" % port
106-
auth_request_uri = self.app.get_authorization_request_url(
105+
auth_request_uri = app.get_authorization_request_url(
107106
self.config["scope"], redirect_uri=redirect_uri)
108107
ac = obtain_auth_code(port, auth_uri=auth_request_uri)
109108
self.assertNotEqual(ac, None)
109+
return (app, ac, redirect_uri)
110+
111+
def test_auth_code(self):
112+
self.skipUnlessWithConfig(["client_id", "scope"])
113+
(self.app, ac, redirect_uri) = self._get_app_and_auth_code()
110114

111115
result = self.app.acquire_token_by_authorization_code(
112116
ac, self.config["scope"], redirect_uri=redirect_uri)
@@ -120,6 +124,46 @@ def test_auth_code(self):
120124
error_description=result.get("error_description")))
121125
self.assertCacheWorksForUser(result, self.config["scope"], username=None)
122126

127+
128+
def test_ssh_cert(self):
129+
self.skipUnlessWithConfig(["client_id", "scope"])
130+
131+
JWK1 = """{"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}"""
132+
JWK2 = """{"kty":"RSA", "n":"72u07mew8rw-ssw3tUs9clKstGO2lvD7ZNxJU7OPNKz5PGYx3gjkhUmtNah4I4FP0DuF1ogb_qSS5eD86w10Wb1ftjWcoY8zjNO9V3ph-Q2tMQWdDW5kLdeU3-EDzc0HQeou9E0udqmfQoPbuXFQcOkdcbh3eeYejs8sWn3TQprXRwGh_TRYi-CAurXXLxQ8rp-pltUVRIr1B63fXmXhMeCAGwCPEFX9FRRs-YHUszUJl9F9-E0nmdOitiAkKfCC9LhwB9_xKtjmHUM9VaEC9jWOcdvXZutwEoW2XPMOg0Ky-s197F9rfpgHle2gBrXsbvVMvS0D-wXg6vsq6BAHzQ", "e":"AQAB"}"""
133+
data1 = {"token_type": "ssh-cert", "key_id": "key1", "req_cnf": JWK1}
134+
ssh_test_slice = {
135+
"dc": "prod-wst-test1",
136+
"slice": "test",
137+
"sshcrt": "true",
138+
}
139+
140+
(self.app, ac, redirect_uri) = self._get_app_and_auth_code()
141+
142+
result = self.app.acquire_token_by_authorization_code(
143+
ac, self.config["scope"], redirect_uri=redirect_uri, data=data1,
144+
params=ssh_test_slice)
145+
self.assertEqual("ssh-cert", result["token_type"])
146+
logger.debug("%s.cache = %s",
147+
self.id(), json.dumps(self.app.token_cache._cache, indent=4))
148+
149+
# acquire_token_silent() needs to be passed the same key to work
150+
account = self.app.get_accounts()[0]
151+
result_from_cache = self.app.acquire_token_silent(
152+
self.config["scope"], account=account, data=data1)
153+
self.assertIsNotNone(result_from_cache)
154+
self.assertEqual(
155+
result['access_token'], result_from_cache['access_token'],
156+
"We should get the cached SSH-cert")
157+
158+
# refresh_token grant can fetch an ssh-cert bound to a different key
159+
refreshed_ssh_cert = self.app.acquire_token_silent(
160+
self.config["scope"], account=account, params=ssh_test_slice,
161+
data={"token_type": "ssh-cert", "key_id": "key2", "req_cnf": JWK2})
162+
self.assertIsNotNone(refreshed_ssh_cert)
163+
self.assertEqual(refreshed_ssh_cert["token_type"], "ssh-cert")
164+
self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token'])
165+
166+
123167
def test_client_secret(self):
124168
self.skipUnlessWithConfig(["client_id", "client_secret"])
125169
self.app = msal.ConfidentialClientApplication(

tests/test_token_cache.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,23 @@ def testAddByAdfs(self):
205205
"appmetadata-fs.msidlab8.com-my_client_id")
206206
)
207207

208+
def test_key_id_is_also_recorded(self):
209+
my_key_id = "some_key_id_123"
210+
self.cache.add({
211+
"data": {"key_id": my_key_id},
212+
"client_id": "my_client_id",
213+
"scope": ["s2", "s1", "s3"], # Not in particular order
214+
"token_endpoint": "https://login.example.com/contoso/v2/token",
215+
"response": self.build_response(
216+
uid="uid", utid="utid", # client_info
217+
expires_in=3600, access_token="an access token",
218+
refresh_token="a refresh token"),
219+
}, now=1000)
220+
cached_key_id = self.cache._cache["AccessToken"].get(
221+
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3',
222+
{}).get("key_id")
223+
self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key")
224+
208225

209226
class SerializableTokenCacheTestCase(TokenCacheTestCase):
210227
# Run all inherited test methods, and have extra check in tearDown()

0 commit comments

Comments
 (0)