Skip to content

Select JWK by kid to get around python-jose bug #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 61 additions & 19 deletions fastapi_third_party_auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_auth(authenticated_user: IDToken = Security(auth.required)):
return f"Hello {authenticated_user.preferred_username}"
"""

from logging import getLogger
from typing import List
from typing import Optional
from typing import Type
Expand All @@ -33,14 +34,16 @@ def test_auth(authenticated_user: IDToken = Security(auth.required)):
from fastapi.security import OAuth2
from fastapi.security import SecurityScopes
from jose import ExpiredSignatureError
from jose import JWTError
from jose import jwt
from jose.exceptions import JWTClaimsError
from jose.exceptions import JWTClaimsError, JWKError, JWTError, JWSError
from requests.exceptions import ConnectionError

from fastapi_third_party_auth import discovery
from fastapi_third_party_auth.grant_types import GrantType
from fastapi_third_party_auth.idtoken_types import IDToken

logger = getLogger(__name__)


class Auth(OAuth2):
def __init__(
Expand Down Expand Up @@ -81,8 +84,19 @@ def __init__(
self.client_id = client_id
self.idtoken_model = idtoken_model
self.scopes = scopes

self.discover = discovery.configure(cache_ttl=signature_cache_ttl)
self.grant_types = grant_types

try:
flows = self.get_flows()
except ConnectionError as e:
logger.warning("Could not discover OIDC flows %s", e)
flows = OAuthFlows()

super().__init__(scheme_name="OIDC", flows=flows, auto_error=False)

def get_flows(self) -> OAuthFlows:
oidc_discoveries = self.discover.auth_server(
openid_connect_url=self.openid_connect_url
)
Expand All @@ -91,36 +105,32 @@ def __init__(
# }

flows = OAuthFlows()
if GrantType.AUTHORIZATION_CODE in grant_types:
if GrantType.AUTHORIZATION_CODE in self.grant_types:
flows.authorizationCode = OAuthFlowAuthorizationCode(
authorizationUrl=self.discover.authorization_url(oidc_discoveries),
tokenUrl=self.discover.token_url(oidc_discoveries),
# scopes=scopes_dict,
)

if GrantType.CLIENT_CREDENTIALS in grant_types:
if GrantType.CLIENT_CREDENTIALS in self.grant_types:
flows.clientCredentials = OAuthFlowClientCredentials(
tokenUrl=self.discover.token_url(oidc_discoveries),
# scopes=scopes_dict,
)

if GrantType.PASSWORD in grant_types:
if GrantType.PASSWORD in self.grant_types:
flows.password = OAuthFlowPassword(
tokenUrl=self.discover.token_url(oidc_discoveries),
# scopes=scopes_dict,
)

if GrantType.IMPLICIT in grant_types:
if GrantType.IMPLICIT in self.grant_types:
flows.implicit = OAuthFlowImplicit(
authorizationUrl=self.discover.authorization_url(oidc_discoveries),
# scopes=scopes_dict,
)

super().__init__(
scheme_name="OIDC",
flows=flows,
auto_error=False,
)

return flows

async def __call__(self, request: Request) -> None:
return None
Expand Down Expand Up @@ -189,6 +199,33 @@ def optional(
auto_error=False,
)


def _find_key(self, token: str) -> dict:
oidc_discoveries = self.discover.auth_server(
openid_connect_url=self.openid_connect_url
)
try:
keys = self.discover.public_keys(oidc_discoveries)["keys"]
except KeyError as e:
raise JWKError("Badly formed JWKs_uri") from e

header = jwt.get_unverified_header(token)
try:
kid = header['kid']
except KeyError as e:
raise JWTError("field 'kid' is missing from JWT headers") from e

for key in keys:
try:
key_kid = key['kid']
except KeyError as e:
raise JWKError("field 'kid' is missing from JWK") from e
if key_kid == kid:
return key

raise JWKError(f"Could not find JWK 'kid'={kid}")


def authenticate_user(
self,
security_scopes: SecurityScopes,
Expand Down Expand Up @@ -222,12 +259,16 @@ def authenticate_user(
)
else:
return None

oidc_discoveries = self.discover.auth_server(
openid_connect_url=self.openid_connect_url
)
key = self.discover.public_keys(oidc_discoveries)

try:
oidc_discoveries = self.discover.auth_server(
openid_connect_url=self.openid_connect_url
)
except ConnectionError as e:
logger.warning("Could not reach auth server %e", e)
raise HTTPException(503, detail="Could not reach auth server") from e
algorithms = self.discover.signing_algos(oidc_discoveries)
key = self._find_key(authorization_credentials.credentials)

try:
id_token = jwt.decode(
Expand All @@ -245,7 +286,8 @@ def authenticate_user(
)

if (
type(id_token["aud"]) == list
"aud" in id_token
and type(id_token["aud"]) == list
and len(id_token["aud"]) >= 1
and "azp" not in id_token
):
Expand Down