From f8bc697022563ed63ca47a2613d179b989e74932 Mon Sep 17 00:00:00 2001 From: David Schultz Date: Thu, 8 Aug 2024 16:54:43 -0500 Subject: [PATCH] add pkce to public device grants (#149) * add pkce to public device grants * update dependencies*.log files(s) * update dependencies*.log files(s) --------- Co-authored-by: github-actions --- dependencies-dev.log | 28 +++--- dependencies-mypy.log | 59 +++++------ dependencies-openapi.log | 22 ++--- dependencies-telemetry.log | 31 +++--- dependencies-tests.log | 26 ++--- dependencies.log | 12 +-- examples/get_device_credentials_token.py | 3 + rest_tools/client/device_client.py | 121 ++++++++++++----------- rest_tools/server/handler.py | 22 +---- rest_tools/utils/pkce.py | 23 +++++ 10 files changed, 185 insertions(+), 162 deletions(-) create mode 100644 rest_tools/utils/pkce.py diff --git a/dependencies-dev.log b/dependencies-dev.log index d804b1b5..fd5d8b1c 100644 --- a/dependencies-dev.log +++ b/dependencies-dev.log @@ -4,19 +4,19 @@ # # pip-compile --extra=dev --output-file=dependencies-dev.log # -cachetools==5.3.3 +cachetools==5.4.0 # via wipac-rest-tools (setup.py) -certifi==2024.6.2 +certifi==2024.7.4 # via requests -cffi==1.16.0 +cffi==1.17.0 # via cryptography charset-normalizer==3.3.2 # via requests -cryptography==42.0.8 +cryptography==43.0.0 # via # pyjwt # wipac-rest-tools (setup.py) -flake8==7.1.0 +flake8==7.1.1 # via wipac-rest-tools (setup.py) httpretty==1.1.4 # via wipac-rest-tools (setup.py) @@ -26,7 +26,7 @@ iniconfig==2.0.0 # via pytest mccabe==0.7.0 # via flake8 -mypy==1.10.1 +mypy==1.11.1 # via wipac-rest-tools (setup.py) mypy-extensions==1.0.0 # via mypy @@ -34,22 +34,22 @@ packaging==24.1 # via pytest pluggy==1.5.0 # via pytest -pycodestyle==2.12.0 +pycodestyle==2.12.1 # via flake8 pycparser==2.22 # via cffi pyflakes==3.2.0 # via flake8 -pyjwt[crypto]==2.8.0 +pyjwt[crypto]==2.9.0 # via wipac-rest-tools (setup.py) pypng==0.20220715.0 # via qrcode -pytest==8.2.2 +pytest==8.3.2 # via # pytest-asyncio # pytest-mock # wipac-rest-tools (setup.py) -pytest-asyncio==0.23.7 +pytest-asyncio==0.23.8 # via wipac-rest-tools (setup.py) pytest-mock==3.14.0 # via wipac-rest-tools (setup.py) @@ -65,11 +65,11 @@ requests-futures==1.0.1 # via wipac-rest-tools (setup.py) requests-mock==1.12.1 # via wipac-rest-tools (setup.py) -ruff==0.5.0 +ruff==0.5.7 # via wipac-rest-tools (setup.py) tornado==6.4.1 # via wipac-rest-tools (setup.py) -types-requests==2.32.0.20240622 +types-requests==2.32.0.20240712 # via wipac-rest-tools (setup.py) typing-extensions==4.12.2 # via @@ -81,7 +81,7 @@ urllib3==2.2.2 # requests # types-requests # wipac-rest-tools (setup.py) -wheel==0.43.0 +wheel==0.44.0 # via wipac-rest-tools (setup.py) -wipac-dev-tools==1.10.6 +wipac-dev-tools==1.12.0 # via wipac-rest-tools (setup.py) diff --git a/dependencies-mypy.log b/dependencies-mypy.log index 4caa21bc..55452e81 100644 --- a/dependencies-mypy.log +++ b/dependencies-mypy.log @@ -4,15 +4,15 @@ # # pip-compile --extra=mypy --output-file=dependencies-mypy.log # -attrs==23.2.0 +attrs==24.2.0 # via # jsonschema # referencing -cachetools==5.3.3 +cachetools==5.4.0 # via wipac-rest-tools (setup.py) -certifi==2024.6.2 +certifi==2024.7.4 # via requests -cffi==1.16.0 +cffi==1.17.0 # via cryptography charset-normalizer==3.3.2 # via requests @@ -26,13 +26,13 @@ colorama==0.4.6 # via crayons coloredlogs==15.0.1 # via wipac-telemetry -coverage[toml]==7.5.4 +coverage[toml]==7.6.1 # via # pytest-cov # wipac-rest-tools (setup.py) crayons==0.4.0 # via pycycle -cryptography==42.0.8 +cryptography==43.0.0 # via # pyjwt # wipac-rest-tools (setup.py) @@ -40,13 +40,14 @@ deprecated==1.2.14 # via # opentelemetry-api # opentelemetry-exporter-otlp-proto-http -flake8==7.1.0 + # opentelemetry-semantic-conventions +flake8==7.1.1 # via wipac-rest-tools (setup.py) googleapis-common-protos==1.59.1 # via # opentelemetry-exporter-jaeger-proto-grpc # opentelemetry-exporter-otlp-proto-http -grpcio==1.64.1 +grpcio==1.65.4 # via opentelemetry-exporter-jaeger-proto-grpc httpretty==1.1.4 # via wipac-rest-tools (setup.py) @@ -54,7 +55,7 @@ humanfriendly==10.0 # via coloredlogs idna==3.7 # via requests -importlib-metadata==7.1.0 +importlib-metadata==8.0.0 # via opentelemetry-api iniconfig==2.0.0 # via pytest @@ -62,7 +63,7 @@ isodate==0.6.1 # via openapi-core jinja2==3.1.4 # via click-completion -jsonschema==4.22.0 +jsonschema==4.23.0 # via # openapi-core # openapi-schema-validator @@ -83,9 +84,9 @@ markupsafe==2.1.5 # werkzeug mccabe==0.7.0 # via flake8 -more-itertools==10.3.0 +more-itertools==10.4.0 # via openapi-core -mypy==1.10.1 +mypy==1.11.1 # via wipac-rest-tools (setup.py) mypy-extensions==1.0.0 # via mypy @@ -97,7 +98,7 @@ openapi-schema-validator==0.6.2 # openapi-spec-validator openapi-spec-validator==0.7.1 # via openapi-core -opentelemetry-api==1.25.0 +opentelemetry-api==1.26.0 # via # opentelemetry-exporter-jaeger-proto-grpc # opentelemetry-exporter-jaeger-thrift @@ -111,21 +112,21 @@ opentelemetry-exporter-jaeger-proto-grpc==1.21.0 # via opentelemetry-exporter-jaeger opentelemetry-exporter-jaeger-thrift==1.21.0 # via opentelemetry-exporter-jaeger -opentelemetry-exporter-otlp-proto-common==1.25.0 +opentelemetry-exporter-otlp-proto-common==1.26.0 # via opentelemetry-exporter-otlp-proto-http -opentelemetry-exporter-otlp-proto-http==1.25.0 +opentelemetry-exporter-otlp-proto-http==1.26.0 # via wipac-telemetry -opentelemetry-proto==1.25.0 +opentelemetry-proto==1.26.0 # via # opentelemetry-exporter-otlp-proto-common # opentelemetry-exporter-otlp-proto-http -opentelemetry-sdk==1.25.0 +opentelemetry-sdk==1.26.0 # via # opentelemetry-exporter-jaeger-proto-grpc # opentelemetry-exporter-jaeger-thrift # opentelemetry-exporter-otlp-proto-http # wipac-telemetry -opentelemetry-semantic-conventions==0.46b0 +opentelemetry-semantic-conventions==0.47b0 # via opentelemetry-sdk packaging==24.1 # via pytest @@ -135,12 +136,12 @@ pathable==0.4.3 # via jsonschema-path pluggy==1.5.0 # via pytest -protobuf==4.25.3 +protobuf==4.25.4 # via # googleapis-common-protos # opentelemetry-proto # wipac-telemetry -pycodestyle==2.12.0 +pycodestyle==2.12.1 # via flake8 pycparser==2.22 # via cffi @@ -148,24 +149,24 @@ pycycle==0.0.8 # via wipac-rest-tools (setup.py) pyflakes==3.2.0 # via flake8 -pyjwt[crypto]==2.8.0 +pyjwt[crypto]==2.9.0 # via wipac-rest-tools (setup.py) pypng==0.20220715.0 # via qrcode -pytest==8.2.2 +pytest==8.3.2 # via # pycycle # pytest-asyncio # pytest-cov # pytest-mock # wipac-rest-tools (setup.py) -pytest-asyncio==0.23.7 +pytest-asyncio==0.23.8 # via wipac-rest-tools (setup.py) pytest-cov==5.0.0 # via wipac-rest-tools (setup.py) pytest-mock==3.14.0 # via wipac-rest-tools (setup.py) -pyyaml==6.0.1 +pyyaml==6.0.2 # via jsonschema-path qrcode==7.4.2 # via wipac-rest-tools (setup.py) @@ -188,11 +189,11 @@ requests-mock==1.12.1 # via wipac-rest-tools (setup.py) rfc3339-validator==0.1.4 # via openapi-schema-validator -rpds-py==0.18.1 +rpds-py==0.20.0 # via # jsonschema # referencing -ruff==0.5.0 +ruff==0.5.7 # via wipac-rest-tools (setup.py) shellingham==1.5.4 # via click-completion @@ -206,7 +207,7 @@ thrift==0.20.0 # via opentelemetry-exporter-jaeger-thrift tornado==6.4.1 # via wipac-rest-tools (setup.py) -types-requests==2.32.0.20240622 +types-requests==2.32.0.20240712 # via wipac-rest-tools (setup.py) typing-extensions==4.12.2 # via @@ -222,9 +223,9 @@ urllib3==2.2.2 # wipac-rest-tools (setup.py) werkzeug==3.0.3 # via openapi-core -wheel==0.43.0 +wheel==0.44.0 # via wipac-rest-tools (setup.py) -wipac-dev-tools==1.10.6 +wipac-dev-tools==1.12.0 # via # wipac-rest-tools (setup.py) # wipac-telemetry diff --git a/dependencies-openapi.log b/dependencies-openapi.log index 9337e6cb..aa274c55 100644 --- a/dependencies-openapi.log +++ b/dependencies-openapi.log @@ -4,25 +4,25 @@ # # pip-compile --extra=openapi --output-file=dependencies-openapi.log # -attrs==23.2.0 +attrs==24.2.0 # via # jsonschema # referencing -cachetools==5.3.3 +cachetools==5.4.0 # via wipac-rest-tools (setup.py) -certifi==2024.6.2 +certifi==2024.7.4 # via requests -cffi==1.16.0 +cffi==1.17.0 # via cryptography charset-normalizer==3.3.2 # via requests -cryptography==42.0.8 +cryptography==43.0.0 # via pyjwt idna==3.7 # via requests isodate==0.6.1 # via openapi-core -jsonschema==4.22.0 +jsonschema==4.23.0 # via # openapi-core # openapi-schema-validator @@ -39,7 +39,7 @@ lazy-object-proxy==1.10.0 # via openapi-spec-validator markupsafe==2.1.5 # via werkzeug -more-itertools==10.3.0 +more-itertools==10.4.0 # via openapi-core openapi-core==0.19.2 # via wipac-rest-tools (setup.py) @@ -55,11 +55,11 @@ pathable==0.4.3 # via jsonschema-path pycparser==2.22 # via cffi -pyjwt[crypto]==2.8.0 +pyjwt[crypto]==2.9.0 # via wipac-rest-tools (setup.py) pypng==0.20220715.0 # via qrcode -pyyaml==6.0.1 +pyyaml==6.0.2 # via jsonschema-path qrcode==7.4.2 # via wipac-rest-tools (setup.py) @@ -78,7 +78,7 @@ requests-futures==1.0.1 # via wipac-rest-tools (setup.py) rfc3339-validator==0.1.4 # via openapi-schema-validator -rpds-py==0.18.1 +rpds-py==0.20.0 # via # jsonschema # referencing @@ -98,5 +98,5 @@ urllib3==2.2.2 # wipac-rest-tools (setup.py) werkzeug==3.0.3 # via openapi-core -wipac-dev-tools==1.10.6 +wipac-dev-tools==1.12.0 # via wipac-rest-tools (setup.py) diff --git a/dependencies-telemetry.log b/dependencies-telemetry.log index 18e60e54..169a6bdf 100644 --- a/dependencies-telemetry.log +++ b/dependencies-telemetry.log @@ -4,35 +4,36 @@ # # pip-compile --extra=telemetry --output-file=dependencies-telemetry.log # -cachetools==5.3.3 +cachetools==5.4.0 # via wipac-rest-tools (setup.py) -certifi==2024.6.2 +certifi==2024.7.4 # via requests -cffi==1.16.0 +cffi==1.17.0 # via cryptography charset-normalizer==3.3.2 # via requests coloredlogs==15.0.1 # via wipac-telemetry -cryptography==42.0.8 +cryptography==43.0.0 # via pyjwt deprecated==1.2.14 # via # opentelemetry-api # opentelemetry-exporter-otlp-proto-http + # opentelemetry-semantic-conventions googleapis-common-protos==1.59.1 # via # opentelemetry-exporter-jaeger-proto-grpc # opentelemetry-exporter-otlp-proto-http -grpcio==1.64.1 +grpcio==1.65.4 # via opentelemetry-exporter-jaeger-proto-grpc humanfriendly==10.0 # via coloredlogs idna==3.7 # via requests -importlib-metadata==7.1.0 +importlib-metadata==8.0.0 # via opentelemetry-api -opentelemetry-api==1.25.0 +opentelemetry-api==1.26.0 # via # opentelemetry-exporter-jaeger-proto-grpc # opentelemetry-exporter-jaeger-thrift @@ -46,30 +47,30 @@ opentelemetry-exporter-jaeger-proto-grpc==1.21.0 # via opentelemetry-exporter-jaeger opentelemetry-exporter-jaeger-thrift==1.21.0 # via opentelemetry-exporter-jaeger -opentelemetry-exporter-otlp-proto-common==1.25.0 +opentelemetry-exporter-otlp-proto-common==1.26.0 # via opentelemetry-exporter-otlp-proto-http -opentelemetry-exporter-otlp-proto-http==1.25.0 +opentelemetry-exporter-otlp-proto-http==1.26.0 # via wipac-telemetry -opentelemetry-proto==1.25.0 +opentelemetry-proto==1.26.0 # via # opentelemetry-exporter-otlp-proto-common # opentelemetry-exporter-otlp-proto-http -opentelemetry-sdk==1.25.0 +opentelemetry-sdk==1.26.0 # via # opentelemetry-exporter-jaeger-proto-grpc # opentelemetry-exporter-jaeger-thrift # opentelemetry-exporter-otlp-proto-http # wipac-telemetry -opentelemetry-semantic-conventions==0.46b0 +opentelemetry-semantic-conventions==0.47b0 # via opentelemetry-sdk -protobuf==4.25.3 +protobuf==4.25.4 # via # googleapis-common-protos # opentelemetry-proto # wipac-telemetry pycparser==2.22 # via cffi -pyjwt[crypto]==2.8.0 +pyjwt[crypto]==2.9.0 # via wipac-rest-tools (setup.py) pypng==0.20220715.0 # via qrcode @@ -99,7 +100,7 @@ urllib3==2.2.2 # via # requests # wipac-rest-tools (setup.py) -wipac-dev-tools==1.10.6 +wipac-dev-tools==1.12.0 # via # wipac-rest-tools (setup.py) # wipac-telemetry diff --git a/dependencies-tests.log b/dependencies-tests.log index 2e77cffa..b9448230 100644 --- a/dependencies-tests.log +++ b/dependencies-tests.log @@ -4,11 +4,11 @@ # # pip-compile --extra=tests --output-file=dependencies-tests.log # -cachetools==5.3.3 +cachetools==5.4.0 # via wipac-rest-tools (setup.py) -certifi==2024.6.2 +certifi==2024.7.4 # via requests -cffi==1.16.0 +cffi==1.17.0 # via cryptography charset-normalizer==3.3.2 # via requests @@ -20,15 +20,15 @@ click-completion==0.5.2 # via pycycle colorama==0.4.6 # via crayons -coverage[toml]==7.5.4 +coverage[toml]==7.6.1 # via # pytest-cov # wipac-rest-tools (setup.py) crayons==0.4.0 # via pycycle -cryptography==42.0.8 +cryptography==43.0.0 # via pyjwt -flake8==7.1.0 +flake8==7.1.1 # via wipac-rest-tools (setup.py) httpretty==1.1.4 # via wipac-rest-tools (setup.py) @@ -46,7 +46,7 @@ packaging==24.1 # via pytest pluggy==1.5.0 # via pytest -pycodestyle==2.12.0 +pycodestyle==2.12.1 # via flake8 pycparser==2.22 # via cffi @@ -54,18 +54,18 @@ pycycle==0.0.8 # via wipac-rest-tools (setup.py) pyflakes==3.2.0 # via flake8 -pyjwt[crypto]==2.8.0 +pyjwt[crypto]==2.9.0 # via wipac-rest-tools (setup.py) pypng==0.20220715.0 # via qrcode -pytest==8.2.2 +pytest==8.3.2 # via # pycycle # pytest-asyncio # pytest-cov # pytest-mock # wipac-rest-tools (setup.py) -pytest-asyncio==0.23.7 +pytest-asyncio==0.23.8 # via wipac-rest-tools (setup.py) pytest-cov==5.0.0 # via wipac-rest-tools (setup.py) @@ -83,7 +83,7 @@ requests-futures==1.0.1 # via wipac-rest-tools (setup.py) requests-mock==1.12.1 # via wipac-rest-tools (setup.py) -ruff==0.5.0 +ruff==0.5.7 # via wipac-rest-tools (setup.py) shellingham==1.5.4 # via click-completion @@ -91,7 +91,7 @@ six==1.16.0 # via click-completion tornado==6.4.1 # via wipac-rest-tools (setup.py) -types-requests==2.32.0.20240622 +types-requests==2.32.0.20240712 # via wipac-rest-tools (setup.py) typing-extensions==4.12.2 # via @@ -102,5 +102,5 @@ urllib3==2.2.2 # requests # types-requests # wipac-rest-tools (setup.py) -wipac-dev-tools==1.10.6 +wipac-dev-tools==1.12.0 # via wipac-rest-tools (setup.py) diff --git a/dependencies.log b/dependencies.log index 91a808da..4d08f573 100644 --- a/dependencies.log +++ b/dependencies.log @@ -4,21 +4,21 @@ # # pip-compile --output-file=dependencies.log # -cachetools==5.3.3 +cachetools==5.4.0 # via wipac-rest-tools (setup.py) -certifi==2024.6.2 +certifi==2024.7.4 # via requests -cffi==1.16.0 +cffi==1.17.0 # via cryptography charset-normalizer==3.3.2 # via requests -cryptography==42.0.8 +cryptography==43.0.0 # via pyjwt idna==3.7 # via requests pycparser==2.22 # via cffi -pyjwt[crypto]==2.8.0 +pyjwt[crypto]==2.9.0 # via wipac-rest-tools (setup.py) pypng==0.20220715.0 # via qrcode @@ -41,5 +41,5 @@ urllib3==2.2.2 # via # requests # wipac-rest-tools (setup.py) -wipac-dev-tools==1.10.6 +wipac-dev-tools==1.12.0 # via wipac-rest-tools (setup.py) diff --git a/examples/get_device_credentials_token.py b/examples/get_device_credentials_token.py index 152c95f7..e0948188 100644 --- a/examples/get_device_credentials_token.py +++ b/examples/get_device_credentials_token.py @@ -1,4 +1,5 @@ import argparse +import logging from rest_tools.client import SavedDeviceGrantAuth @@ -14,6 +15,8 @@ def main(): parser.add_argument('--address', default='https://keycloak.icecube.wisc.edu/auth/realms/IceCube', help='OAuth2 server address') parser.add_argument('client_id', help='client id') + logging.basicConfig(level=logging.DEBUG) + args = parser.parse_args() kwargs = vars(args) print('access token:', get_token(**kwargs)) diff --git a/rest_tools/client/device_client.py b/rest_tools/client/device_client.py index f37c3b0e..fcc0f960 100644 --- a/rest_tools/client/device_client.py +++ b/rest_tools/client/device_client.py @@ -14,6 +14,7 @@ from .openid_client import OpenIDRestClient from ..utils.auth import OpenIDAuth +from ..utils.pkce import PKCEMixin def _print_qrcode(req: Dict[str, str]) -> None: @@ -58,53 +59,29 @@ def _print_qrcode(req: Dict[str, str]) -> None: # fmt:on -def _perform_device_grant( - logger: logging.Logger, - device_url: str, - token_url: str, - client_id: str, - client_secret: Optional[str] = None, - scopes: Optional[List[str]] = None, -) -> str: - args = { - 'client_id': client_id, - 'scope': 'offline_access ' + (' '.join(scopes) if scopes else ''), - } - if client_secret: - args['client_secret'] = client_secret - - try: - r = requests.post(device_url, data=args) - r.raise_for_status() - req = r.json() - except requests.exceptions.HTTPError as exc: - logger.debug('%r', exc.response.text) - try: - req = exc.response.json() - except Exception: - req = {} - error = req.get('error', '') - raise RuntimeError(f'Device authorization failed: {error}') from exc - except Exception as exc: - raise RuntimeError('Device authorization failed') from exc - - logger.debug('Device auth in progress') - - _print_qrcode(req) - - args = { - 'grant_type': 'urn:ietf:params:oauth:grant-type:device_code', - 'device_code': req['device_code'], - 'client_id': client_id, - } - if client_secret: - args['client_secret'] = client_secret - - sleep_time = int(req.get('interval', 5)) - while True: - time.sleep(sleep_time) +class CommonDeviceGrant(PKCEMixin): + def perform_device_grant( + self, + logger: logging.Logger, + device_url: str, + token_url: str, + client_id: str, + client_secret: Optional[str] = None, + scopes: Optional[List[str]] = None, + ) -> str: + args = { + 'client_id': client_id, + 'scope': 'offline_access ' + (' '.join(scopes) if scopes else ''), + } + if client_secret: + args['client_secret'] = client_secret + else: + code_challenge = self.create_pkce_challenge() + args['code_challenge'] = code_challenge + args['code_challenge_method'] = 'S256' + try: - r = requests.post(token_url, data=args) + r = requests.post(device_url, data=args) r.raise_for_status() req = r.json() except requests.exceptions.HTTPError as exc: @@ -114,17 +91,49 @@ def _perform_device_grant( except Exception: req = {} error = req.get('error', '') - if error == 'authorization_pending': - continue - elif error == 'slow_down': - sleep_time += 5 - continue raise RuntimeError(f'Device authorization failed: {error}') from exc except Exception as exc: raise RuntimeError('Device authorization failed') from exc - break - return req['refresh_token'] + logger.debug('Device auth in progress') + + _print_qrcode(req) + + args = { + 'grant_type': 'urn:ietf:params:oauth:grant-type:device_code', + 'device_code': req['device_code'], + 'client_id': client_id, + } + if client_secret: + args['client_secret'] = client_secret + else: + args['code_verifier'] = self.get_pkce_verifier(code_challenge) + + sleep_time = int(req.get('interval', 5)) + while True: + time.sleep(sleep_time) + try: + r = requests.post(token_url, data=args) + r.raise_for_status() + req = r.json() + except requests.exceptions.HTTPError as exc: + logger.debug('%r', exc.response.text) + try: + req = exc.response.json() + except Exception: + req = {} + error = req.get('error', '') + if error == 'authorization_pending': + continue + elif error == 'slow_down': + sleep_time += 5 + continue + raise RuntimeError(f'Device authorization failed: {error}') from exc + except Exception as exc: + raise RuntimeError('Device authorization failed') from exc + break + + return req['refresh_token'] def DeviceGrantAuth( @@ -153,7 +162,8 @@ def DeviceGrantAuth( raise RuntimeError('Device grant not supported by server') endpoint: str = auth.provider_info['device_authorization_endpoint'] # type: ignore - refresh_token = _perform_device_grant( + device = CommonDeviceGrant() + refresh_token = device.perform_device_grant( logger, endpoint, auth.token_url, client_id, client_secret, scopes ) @@ -231,7 +241,8 @@ def update_func(access, refresh): raise RuntimeError('Device grant not supported by server') endpoint: str = auth.provider_info['device_authorization_endpoint'] # type: ignore - refresh_token = _perform_device_grant( + device = CommonDeviceGrant() + refresh_token = device.perform_device_grant( logger, endpoint, auth.token_url, client_id, client_secret, scopes ) diff --git a/rest_tools/server/handler.py b/rest_tools/server/handler.py index c416c954..f181dc30 100644 --- a/rest_tools/server/handler.py +++ b/rest_tools/server/handler.py @@ -6,15 +6,13 @@ import base64 import functools -import hashlib import hmac import json import logging -import secrets import time import urllib.parse from collections import defaultdict -from typing import Any, Dict, MutableMapping, Union +from typing import Any, Dict, Union import rest_tools import tornado.escape @@ -22,12 +20,12 @@ import tornado.httpclient import tornado.httputil import tornado.web -from cachetools import TTLCache from tornado.auth import OAuth2Mixin from .. import telemetry as wtt from ..utils.auth import Auth, OpenIDAuth from ..utils.json_util import json_decode +from ..utils.pkce import PKCEMixin from .decorators import catch_error from .stats import RouteStats @@ -324,7 +322,7 @@ def clear_tokens(self): self.clear_cookie('user_info') -class OpenIDLoginHandler(OpenIDCookieHandlerMixin, OAuth2Mixin, RestHandler): +class OpenIDLoginHandler(OpenIDCookieHandlerMixin, OAuth2Mixin, PKCEMixin, RestHandler): """Handle OpenID Connect logins. Should be combined with an appropriate mixin to store the token(s). @@ -332,7 +330,6 @@ class OpenIDLoginHandler(OpenIDCookieHandlerMixin, OAuth2Mixin, RestHandler): Requires the `login_url` application setting to be a full url. """ - _pkcs_challenges: MutableMapping[str, str] = TTLCache(maxsize=10000, ttl=3600) def initialize(self, oauth_client_id, oauth_client_secret, oauth_client_scope=None, **kwargs): super().initialize(**kwargs) @@ -354,19 +351,6 @@ def initialize(self, oauth_client_id, oauth_client_secret, oauth_client_scope=No scopes.add('offline_access') self.oauth_client_scope = list(scopes) - @classmethod - def create_pkce_challenge(cls) -> str: - code_verifier = secrets.token_urlsafe(64) - code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode('utf-8')).digest()).decode('utf-8').split('=')[0] - cls._pkcs_challenges[code_challenge] = code_verifier - return code_challenge - - @classmethod - def get_pkce_verifier(cls, challenge: str) -> str: - if challenge in cls._pkcs_challenges: - return cls._pkcs_challenges[challenge] - raise KeyError('invalid pkce challenge') - async def get_authenticated_user( self, redirect_uri: str, code: str, state: Dict[str, Any] ) -> Dict[str, Any]: diff --git a/rest_tools/utils/pkce.py b/rest_tools/utils/pkce.py new file mode 100644 index 00000000..c9e55e40 --- /dev/null +++ b/rest_tools/utils/pkce.py @@ -0,0 +1,23 @@ +import base64 +import hashlib +import secrets +from typing import MutableMapping + +from cachetools import TTLCache + + +class PKCEMixin: + _pkcs_challenges: MutableMapping[str, str] = TTLCache(maxsize=10000, ttl=3600) + + @classmethod + def create_pkce_challenge(cls) -> str: + code_verifier = secrets.token_urlsafe(64) + code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode('utf-8')).digest()).decode('utf-8').split('=')[0] + cls._pkcs_challenges[code_challenge] = code_verifier + return code_challenge + + @classmethod + def get_pkce_verifier(cls, challenge: str) -> str: + if challenge in cls._pkcs_challenges: + return cls._pkcs_challenges[challenge] + raise KeyError('invalid pkce challenge')