Skip to content

Commit

Permalink
fixup! fixup! Issue #237 WIP 3
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Apr 25, 2023
1 parent 9a023ca commit 0c346a9
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions openeo/rest/auth/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import functools
import hashlib
import http.server
import inspect
import json
import logging
import math
Expand Down Expand Up @@ -661,6 +662,14 @@ def _get_token_endpoint_post_data(self) -> dict:
)


def _like_print(display: Callable) -> Callable:
"""Ensure that display function supports an `end` argument like `print`"""
if display is print or "end" in inspect.signature(display).parameters:
return display
else:
return lambda *args, end="\n", **kwargs: display(*args, **kwargs)


class _BasicDeviceCodePollUi:
"""
Basic (print + carriage return) implementation of the device code
Expand All @@ -672,12 +681,13 @@ def __init__(
timeout: float,
elapsed: Callable[[], float],
max_width: int = 80,
display: Callable = print,
):
self.timeout = timeout
self.elapsed = elapsed
self._max_width = max_width
self._status = "Authorization pending"
# TODO: use Unicode block elements for progress bar (https://en.wikipedia.org/wiki/Block_Elements)?
self._display = _like_print(display)
self._bar_chars = ("[", "#", "-", "]")

def _instructions(self, info: VerificationInfo) -> str:
Expand All @@ -687,7 +697,7 @@ def _instructions(self, info: VerificationInfo) -> str:
return f"Visit {info.verification_uri} and enter user code {info.user_code!r} to authenticate."

def show_instructions(self, info: VerificationInfo) -> None:
print(self._instructions(info=info))
self._display(self._instructions(info=info))

def set_status(self, status: str):
self._status = status
Expand All @@ -702,7 +712,7 @@ def show_progress(self, status: Optional[str] = None):
if status:
self.set_status(status)
col_width = (self._max_width - 1) // 2
print(f"{self._progress_bar(width=col_width)} {self._status[:col_width]}", end="\r")
self._display(f"{self._progress_bar(width=col_width)} {self._status[:col_width]}", end="\r")


class _JupyterDeviceCodePollUi(_BasicDeviceCodePollUi):
Expand Down Expand Up @@ -753,7 +763,7 @@ def __init__(
requests_session: Optional[requests.Session] = None,
):
super().__init__(client_info=client_info, requests_session=requests_session)
self._display = display # TODO: this is unused now
self._display = display
# Allow to specify/override device code URL for cases when it is not available in OIDC discovery doc.
self._device_code_url = device_code_url or self._provider_config.get("device_authorization_endpoint")
if not self._device_code_url:
Expand Down Expand Up @@ -819,7 +829,7 @@ def get_tokens(self, request_refresh_token: bool = False) -> AccessTokenResult:
if in_jupyter_context():
poll_ui = _JupyterDeviceCodePollUi(timeout=self._max_poll_time, elapsed=elapsed)
else:
poll_ui = _BasicDeviceCodePollUi(timeout=self._max_poll_time, elapsed=elapsed)
poll_ui = _BasicDeviceCodePollUi(timeout=self._max_poll_time, elapsed=elapsed, display=self._display)
poll_ui.show_instructions(info=verification_info)

while elapsed() <= self._max_poll_time:
Expand Down

0 comments on commit 0c346a9

Please sign in to comment.