Skip to content

feat(sdk): allow for multi config #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ export STAX_ACCESS_KEY=<your_access_key>
export STAX_SECRET_KEY=<your_secret_key>
```

##### Client Auth Configuration
You can configure each client individually by passing in a config on init.
When a client is created it's configuration will be locked in and any change to the configurations will not affect the client.

This can be seen in our [guide](https://github.com/stax-labs/lib-stax-python-sdk/blob/master/examples/auth.py).

*Optional configuration:*

##### Authentication token expiry
Expand Down
33 changes: 33 additions & 0 deletions examples/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import json
import os

from staxapp.config import Config
from staxapp.openapi import StaxClient
from customscript import get_hostnames
hostname = get_hostnames()

access_key = os.getenv("STAX_ACCESS_KEY")
secret_key = os.getenv("STAX_SECRET_KEY")

Config.hostname = hostname["au1"]
Config.access_key = access_key
Config.secret_key = secret_key

accounts_au1 = StaxClient('accounts')

au1_response = accounts_au1.CreateAccountType(
Name="sdk-au1"
)

print(json.dumps(au1_response, indent=4, sort_keys=True))

access_key_2 = os.getenv("STAX_ACCESS_KEY_2")
secret_key_2 = os.getenv("STAX_SECRET_KEY_2")
config = Config(hostname=hostname["us1"], access_key=access_key_2, secret_key=secret_key_2)

us1_accounts = StaxClient('accounts', config=config)

us1_response = us1_accounts.CreateAccountType(
Name="sdk-us1"
)
print(json.dumps(us1_response, indent=4, sort_keys=True))
50 changes: 26 additions & 24 deletions staxapp/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@


class Api:
_requests_auth = None
@classmethod
def get_config(cls, config=None):
if config is None:
config = Config.GetDefaultConfig()
config.init()
return config

@classmethod
def _headers(cls, custom_headers) -> dict:
Expand All @@ -17,12 +22,6 @@ def _headers(cls, custom_headers) -> dict:
}
return headers

@classmethod
def _auth(cls, **kwargs):
if not cls._requests_auth:
cls._requests_auth = Config.get_auth_class().requests_auth
return cls._requests_auth(Config.access_key, Config.secret_key, **kwargs)

@staticmethod
def handle_api_response(response):
try:
Expand All @@ -31,13 +30,13 @@ def handle_api_response(response):
raise ApiException(str(e), response)

@classmethod
def get(cls, url_frag, params={}, **kwargs):
url_frag = url_frag.replace(f"/{Config.API_VERSION}", "")
url = f"{Config.api_base_url()}/{url_frag.lstrip('/')}"

def get(cls, url_frag, params={}, config=None, **kwargs):
config = cls.get_config(config)
url_frag = url_frag.replace(f"/{config.API_VERSION}", "")
url = f"{config.api_base_url()}/{url_frag.lstrip('/')}"
response = requests.get(
url,
auth=cls._auth(),
auth=config._auth(),
params=params,
headers=cls._headers(kwargs.get("headers", {})),
**kwargs,
Expand All @@ -46,43 +45,46 @@ def get(cls, url_frag, params={}, **kwargs):
return response.json()

@classmethod
def post(cls, url_frag, payload={}, **kwargs):
url_frag = url_frag.replace(f"/{Config.API_VERSION}", "")
url = f"{Config.api_base_url()}/{url_frag.lstrip('/')}"
def post(cls, url_frag, payload={}, config=None, **kwargs):
config = cls.get_config(config)
url_frag = url_frag.replace(f"/{config.API_VERSION}", "")
url = f"{config.api_base_url()}/{url_frag.lstrip('/')}"

response = requests.post(
url,
json=payload,
auth=cls._auth(),
auth=config._auth(),
headers=cls._headers(kwargs.get("headers", {})),
**kwargs,
)
cls.handle_api_response(response)
return response.json()

@classmethod
def put(cls, url_frag, payload={}, **kwargs):
url_frag = url_frag.replace(f"/{Config.API_VERSION}", "")
url = f"{Config.api_base_url()}/{url_frag.lstrip('/')}"
def put(cls, url_frag, payload={}, config=None, **kwargs):
config = cls.get_config(config)
url_frag = url_frag.replace(f"/{config.API_VERSION}", "")
url = f"{config.api_base_url()}/{url_frag.lstrip('/')}"

response = requests.put(
url,
json=payload,
auth=cls._auth(),
auth=config._auth(),
headers=cls._headers(kwargs.get("headers", {})),
**kwargs,
)
cls.handle_api_response(response)
return response.json()

@classmethod
def delete(cls, url_frag, params={}, **kwargs):
url_frag = url_frag.replace(f"/{Config.API_VERSION}", "")
url = f"{Config.api_base_url()}/{url_frag.lstrip('/')}"
def delete(cls, url_frag, params={}, config=None, **kwargs):
config = cls.get_config(config)
url_frag = url_frag.replace(f"/{config.API_VERSION}", "")
url = f"{config.api_base_url()}/{url_frag.lstrip('/')}"

response = requests.delete(
url,
auth=cls._auth(),
auth=config._auth(),
params=params,
headers=cls._headers(kwargs.get("headers", {})),
**kwargs,
Expand Down
44 changes: 25 additions & 19 deletions staxapp/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@


class StaxAuth:
def __init__(self, config_branch, max_retries: int = 3):
config = StaxConfig.api_config

self.identity_pool = config.get(config_branch).get("identityPoolId")
self.user_pool = config.get(config_branch).get("userPoolId")
self.client_id = config.get(config_branch).get("userPoolWebClientId")
self.aws_region = config.get(config_branch).get("region")
def __init__(self, config_branch: str, config: StaxConfig, max_retries: int = 3):
self.config = config
api_config = self.config.api_config
self.identity_pool = api_config.get(config_branch).get("identityPoolId")
self.user_pool = api_config.get(config_branch).get("userPoolId")
self.client_id = api_config.get(config_branch).get("userPoolWebClientId")
self.aws_region = api_config.get(config_branch).get("region")
self.max_retries = max_retries

def requests_auth(self, username, password, **kwargs):
def requests_auth(self, **kwargs):
username = self.config.access_key
password = self.config.secret_key
if username is None:
raise InvalidCredentialsException(
"Please provide an Access Key to your config"
Expand All @@ -37,10 +39,10 @@ def requests_auth(self, username, password, **kwargs):
id_creds = self.sts_from_cognito_identity_pool(id_token, **kwargs)
auth = self.sigv4_signed_auth_headers(id_creds)

StaxConfig.expiration = id_creds.get("Credentials").get("Expiration")
StaxConfig.auth = auth
self.config.expiration = id_creds.get("Credentials").get("Expiration")
self.config.auth = auth

return StaxConfig.auth
return self.config.auth

def id_token_from_cognito(
self, username=None, password=None, srp_client=None, **kwargs
Expand All @@ -59,6 +61,7 @@ def id_token_from_cognito(
client_id=self.client_id,
client=srp_client,
)

try:
tokens = aws.authenticate_user()
except ClientError as e:
Expand All @@ -69,7 +72,7 @@ def id_token_from_cognito(
elif e.response["Error"]["Code"] == "UserNotFoundException":
raise InvalidCredentialsException(
message=str(e),
detail="Please check your Access Key, that you have created your Api Token and that you are using the right STAX REGION",
detail=f"Please check your Access Key, that you have created your Api Token and that you are using the right STAX REGION",
)
else:
raise InvalidCredentialsException(
Expand Down Expand Up @@ -121,7 +124,7 @@ def sigv4_signed_auth_headers(self, id_creds):
aws_access_key=id_creds.get("Credentials").get("AccessKeyId"),
aws_secret_access_key=id_creds.get("Credentials").get("SecretKey"),
aws_token=id_creds.get("Credentials").get("SessionToken"),
aws_host=f"{StaxConfig.hostname}",
aws_host=f"{self.config.hostname}",
aws_region=self.aws_region,
aws_service="execute-api",
)
Expand All @@ -133,17 +136,20 @@ class RootAuth:
def requests_auth(username, password, **kwargs):
if StaxConfig.expiration and StaxConfig.expiration > datetime.now(timezone.utc):
return StaxConfig.auth

return StaxAuth("JumaAuth").requests_auth(username, password, **kwargs)
config = StaxConfig.GetDefaultConfig()
config.init()
config.access_key = username
config.secret_key = password
return StaxAuth("JumaAuth", config).requests_auth(**kwargs)


class ApiTokenAuth:
@staticmethod
def requests_auth(username, password, **kwargs):
def requests_auth(config: StaxConfig, **kwargs):
# Minimize the potentical for token to expire while still being used for auth (say within a lambda function)
if StaxConfig.expiration and StaxConfig.expiration - timedelta(
if config.expiration and config.expiration - timedelta(
minutes=int(environ.get("TOKEN_EXPIRY_THRESHOLD_IN_MINS", 1))
) > datetime.now(timezone.utc):
return StaxConfig.auth
return config.auth

return StaxAuth("ApiAuth").requests_auth(username, password, **kwargs)
return StaxAuth("ApiAuth", config).requests_auth(**kwargs)
59 changes: 39 additions & 20 deletions staxapp/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import os
import platform as sysinfo
from distutils.command.config import config
from email.policy import default

import requests

Expand All @@ -18,26 +20,34 @@ class Config:
STAX_REGION = os.getenv("STAX_REGION", "au1.staxapp.cloud")
API_VERSION = "20190206"

cached_api_config = dict()
api_config = dict()
access_key = None
secret_key = None
auth_class = None
auth = None
_requests_auth = None
_initialized = False
base_url = None
hostname = f"api.{STAX_REGION}"
org_id = None
auth = None
expiration = None
load_live_schema = True

platform = sysinfo.platform()
python_version = sysinfo.python_version()
sdk_version = staxapp.__version__

def set_config(self):
self.base_url = f"https://{self.hostname}/{self.API_VERSION}"
config_url = f"{self.api_base_url()}/public/config"
if config_url == self.cached_api_config.get("caching"):
self.api_config = self.cached_api_config
else:
self.api_config = Config.get_api_config(config_url)

@classmethod
def set_config(cls):
cls.base_url = f"https://{cls.hostname}/{cls.API_VERSION}"
config_url = f"{cls.api_base_url()}/public/config"
def get_api_config(cls, config_url):
config_response = requests.get(config_url)
try:
config_response.raise_for_status()
Expand All @@ -46,30 +56,42 @@ def set_config(cls):
raise ApiException(
str(e), config_response, detail=" Could not load API config."
)

cls.api_config = config_response.json()

@classmethod
def init(cls, config=None):
if cls._initialized:
cls.cached_api_config = config_response.json()
cls.cached_api_config["caching"] = config_url
return config_response.json()

def __init__(self, hostname=None, access_key=None, secret_key=None):
if hostname is not None:
self.hostname = hostname
self.access_key = access_key
self.secret_key = secret_key

def init(self):
if self._initialized:
return
self.set_config()

if not config:
cls.set_config()
self._initialized = True

cls._initialized = True
def _auth(self, **kwargs):
if not self._requests_auth:
self._requests_auth = self.get_auth_class().requests_auth
return self._requests_auth(self, **kwargs)

@classmethod
def api_base_url(cls):
return cls.base_url
def api_base_url(self):
return self.base_url

@classmethod
def GetDefaultConfig(cls):
config = Config(Config.hostname, Config.access_key, Config.secret_key)
return config

def branch(cls):
return os.getenv("STAX_BRANCH", "master")

@classmethod
def schema_url(cls):
return f"{cls.base_url}/public/api-document"
return f"https://{cls.hostname}/{cls.API_VERSION}/public/api-document"

@classmethod
def get_auth_class(cls):
Expand All @@ -78,6 +100,3 @@ def get_auth_class(cls):

cls.auth_class = ApiTokenAuth
return cls.auth_class


Config.init()
2 changes: 1 addition & 1 deletion staxapp/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def validate(cls, data, component):
@staticmethod
def default_swagger_template() -> dict:
# Get the default swagger template from https://api.au1.staxapp.cloud/20190206/public/api-document
schema_response = requests.get(Config.schema_url()).json()
schema_response = requests.get(Config.GetDefaultConfig().schema_url()).json()
template = dict(
openapi="3.0.0",
info={
Expand Down
Loading