Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 61 additions & 9 deletions src/picterra/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,26 @@
else:
from typing_extensions import Literal, TypedDict

from typing import Any, Generic, TypeVar
from typing import Any, Generic, Optional, TypeVar
from urllib.parse import urlencode, urljoin

import requests
from requests.adapters import HTTPAdapter
from requests.auth import AuthBase
from urllib3.util.retry import Retry

from .utils.oauth import OAuthClient, OAuthError

logger = logging.getLogger()

CHUNK_SIZE_BYTES = 8192 # 8 KiB

# ANSI escape codes for colors
GREEN = "\033[92m"
RED = "\033[91m"
RESET = "\033[0m" # Resets the color to default

CLIENT_ID = "Eya1oJleyYoo35I17w5WWP2oTbTLr89LTJXWBxDs"

# allow injecting an non-existing package name to test the fallback behavior
# of _get_ua in tests (see test_headers_user_agent_version__fallback)
Expand Down Expand Up @@ -64,6 +73,7 @@ class _RequestsSession(requests.Session):

def __init__(self, *args, **kwargs):
self.timeout = kwargs.pop("timeout")
#self.auth = kwargs.pop("auth")
super().__init__(*args, **kwargs)
self.headers.update(
{
Expand All @@ -73,6 +83,7 @@ def __init__(self, *args, **kwargs):

def request(self, *args, **kwargs):
kwargs.setdefault("timeout", self.timeout)
#kwargs.setdefault("auth", self.auth)
return super().request(*args, **kwargs)


Expand Down Expand Up @@ -182,34 +193,75 @@ class FeatureCollection(TypedDict):
features: list[Feature]


class AuthInitError(Exception):
pass


class ApiKeyAuth(AuthBase):
api_key: str

def __init__(self):
if os.environ.get("PICTERRA_API_KEY", None) is None:
raise AuthInitError("PICTERRA_API_KEY environment variable not set")
self.api_key = os.environ.get("PICTERRA_API_KEY", None)

def __call__(self, r):
r.headers['X-Api-Key'] = self.api_key
return r


class Oauth2Auth(AuthBase):
oauth_token: dict

def __init__(self, base_url: str):
cl = OAuthClient(CLIENT_ID, base_url)
try:
data = cl.start()
self.oauth_token = data
print(f"{GREEN}Logged in at {base_url}.{RESET}")
except OAuthError as e:
raise SystemExit(f"{RED}Error during login: '{e}'{RESET}")

def __call__(self, r):
r.headers['Authorization'] = "Bearer " + self.oauth_token["token"]
return r


class BaseAPIClient:
"""
Base class for Picterra API clients.

This is subclassed for the different products we have.
"""
base_url: str
sess: _RequestsSession

def __init__(
self, api_url: str, timeout: int = 30, max_retries: int = 3, backoff_factor: int = 10
self, api_url: str, timeout: int = 30, max_retries: int = 3, backoff_factor: int = 10, auth: Literal["apikey", "oauth2"] = "apikey",
):
"""
Args:
api_url: the api's base url. This is different based on the Picterra product used
and is typically defined by implementations of this client
timeout: number of seconds before the request times out
max_retries: max attempts when ecountering gateway issues or throttles; see
max_retries: max attempts when encountering gateway issues or throttles; see
retry_strategy comment below
backoff_factor: factor used nin the backoff algorithm; see retry_strategy comment below
auth: TODO
"""
base_url = os.environ.get(
"PICTERRA_BASE_URL", "https://app.picterra.ch/"
)
api_key = os.environ.get("PICTERRA_API_KEY", None)
if not api_key:
raise APIError("PICTERRA_API_KEY environment variable is not defined")
if auth == "apikey":
auth_class = ApiKeyAuth()
elif auth == "oauth2":
auth_class = Oauth2Auth(base_url)
else:
raise RuntimeError("Invalid authentication method. Must be 'apikey' or 'oauth2'.")
logger.info(
"Using base_url=%s, api_url=%s; %d max retries, %d backoff and %s timeout.",
"Using base_url=%s, auth=%s; api_url=%s; %d max retries, %d backoff and %s timeout.",
base_url,
auth,
api_url,
max_retries,
backoff_factor,
Expand All @@ -219,6 +271,7 @@ def __init__(
# Create the session with a default timeout (30 sec), that we can then
# override on a per-endpoint basis (will be disabled for file uploads and downloads)
self.sess = _RequestsSession(timeout=timeout)
self.sess.auth = auth_class
# Retry: we set the HTTP codes for our throttle (429) plus possible gateway problems (50*),
# and for polling methods (GET), as non-idempotent ones should be addressed via idempotency
# key mechanism; given the algorithm is {<backoff_factor> * (2 **<retries-1>}, and we
Expand All @@ -233,8 +286,6 @@ def __init__(
adapter = HTTPAdapter(max_retries=retry_strategy)
self.sess.mount("https://", adapter)
self.sess.mount("http://", adapter)
# Authentication
self.sess.headers.update({"X-Api-Key": api_key})

def _full_url(self, path: str, params: dict[str, Any] | None = None):
url = urljoin(self.base_url, path)
Expand Down Expand Up @@ -295,3 +346,4 @@ def get_operation_results(self, operation_id: str) -> dict[str, Any]:
self._full_url("operations/%s/" % operation_id),
)
return resp.json()["results"]

Empty file added src/picterra/utils/__init__.py
Empty file.
Loading
Loading