Skip to content

Commit

Permalink
Refactored client and added missing config methods (env) (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-vectara authored Oct 18, 2024
1 parent c25b2d6 commit 5610858
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 146 deletions.
9 changes: 5 additions & 4 deletions examples/01_getting_started/getting_started_util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from vectara.factory import Factory, WrappedVectara
from vectara.factory import Factory
from vectara.managers import CreateCorpusRequest
from pathlib import Path
from vectara.types import StructuredDocument
from typing import List, Dict
from vectara.utils import LabHelper
from vectara import Vectara
import re
import logging

Expand All @@ -14,11 +15,11 @@ def __init__(self):
datefmt='%H:%M:%S %z')
self.logger = logging.getLogger(self.__class__.__name__)

def _check_initialized(self, client: WrappedVectara):
def _check_initialized(self, client: Vectara):
if not client.lab_helper:
raise Exception("Client not initialized correctly")

def setup_01(self, client: WrappedVectara) -> str:
def setup_01(self, client: Vectara) -> str:
self.logger.info("Setting up Lab 01")
self._check_initialized(client)
request = CreateCorpusRequest(name="Getting Started - Query API", key="01-getting-started-query-api")
Expand Down Expand Up @@ -49,7 +50,7 @@ def setup_01(self, client: WrappedVectara) -> str:
self.logger.info("Lab setup for 01 complete")
return response.key

def setup_02(self, client: WrappedVectara) -> str:
def setup_02(self, client: Vectara) -> str:
self.logger.info("Setting up Lab 02")
self._check_initialized(client) # type: ignore
request = CreateCorpusRequest(name="Getting Started - Index API", key="02-getting-started-index-api")
Expand Down
74 changes: 74 additions & 0 deletions int_tests/vectara_int_tests/managers/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os

from vectara.factory import Factory
from vectara.config import HomeConfigLoader, EnvConfigLoader, ApiKeyAuthConfig, OAuth2AuthConfig
from pathlib import Path

import unittest
import logging

class FactoryConfigTest(unittest.TestCase):
"""
This test depends on our YAML default config being defined.
We use this to test various methods of injection.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
logging.basicConfig(format='%(asctime)s:%(name)-35s %(levelname)s:%(message)s', level=logging.INFO,
datefmt='%H:%M:%S %z')
self.logger = logging.getLogger(self.__class__.__name__)

def _test_factory_auth(self, target: Factory, expected_method: str):
client = target.build()
self.assertEqual(expected_method, target.load_method)

if not client.corpus_manager:
raise Exception("Corpus manager should be defined")

results = client.corpus_manager.find_corpora_with_filter("", 1)
if results and len(results) > 0:
self.logger.info(f"Found corpus [{results[0].key}]")


def test_default_load(self):
factory = Factory()
self._test_factory_auth(factory, "path_home")

def test_explicit_path(self):
factory = Factory(config_path=str(Path.home().resolve()))
self._test_factory_auth(factory, "path_explicit")

def test_env(self):
client_config = HomeConfigLoader().load()
os.environ[EnvConfigLoader.ENV_CUSTOMER_ID] = client_config.customer_id
if isinstance(client_config.auth, ApiKeyAuthConfig):
os.environ[EnvConfigLoader.ENV_API_KEY] = client_config.auth.api_key
elif isinstance(client_config.auth, OAuth2AuthConfig):
os.environ[EnvConfigLoader.ENV_OAUTH2_CLIENT_ID] = client_config.auth.app_client_id
os.environ[EnvConfigLoader.ENV_OAUTH2_CLIENT_SECRET] = client_config.auth.app_client_secret

try:
factory = Factory()
self._test_factory_auth(factory, "env")
finally:
if isinstance(client_config.auth, ApiKeyAuthConfig):
del os.environ[EnvConfigLoader.ENV_API_KEY]
elif isinstance(client_config.auth, OAuth2AuthConfig):
del os.environ[EnvConfigLoader.ENV_OAUTH2_CLIENT_ID]
del os.environ[EnvConfigLoader.ENV_OAUTH2_CLIENT_SECRET]

def test_explicit_typed(self):
client_config = HomeConfigLoader().load()
factory = Factory(config=client_config)
self._test_factory_auth(factory, "explicit_typed")

def test_explicit_dict(self):
client_config = HomeConfigLoader().load().model_dump()
factory = Factory(config=client_config)
self._test_factory_auth(factory, "explicit_dict")


if __name__ == '__main__':
unittest.main()
34 changes: 32 additions & 2 deletions src/vectara/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,37 @@
from .base_client import BaseVectara, AsyncBaseVectara
from vectara.managers.corpus import CorpusManager
from vectara.managers.upload import UploadManager, UploadWrapper
from vectara.utils import LabHelper

class Vectara(BaseVectara):
pass
from typing import Union, Optional, Callable
import logging

class Vectara(BaseVectara):
"""
We extend the Vectara client, adding additional helper services. Due to the methodology used in the
Vectara constructor, hard-coding dependencies and using an internal wrapper, we use lazy initialization
for the helper classes like the CorpusManager.
TODO Change Client to build dependencies inside constructor (harder to decouple, but removes optionality)
"""

def __init__(self, *args,
**kwargs):
super().__init__(*args, **kwargs)
self.logger = logging.getLogger(self.__class__.__name__)
self.corpus_manager: Union[None, CorpusManager] = None
self.upload_manager: Union[None, UploadManager] = None
self.lab_helper: Union[None, LabHelper] = None

def set_corpus_manager(self, corpus_manager: CorpusManager) -> None:
self.corpus_manager = corpus_manager

def set_upload_manager(self, upload_manager: UploadManager) -> None:
self.upload_manager = upload_manager

def set_lab_helper(self, lab_helper: LabHelper) -> None:
self.lab_helper = lab_helper

class AsyncVectara(AsyncBaseVectara):
pass
4 changes: 2 additions & 2 deletions src/vectara/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .config import (BaseConfigLoader, HomeConfigLoader, JsonConfigLoader, BaseAuthConfig,
ClientConfig, OAuth2AuthConfig, ApiKeyAuthConfig)
from .config import (BaseConfigLoader, HomeConfigLoader, BaseAuthConfig,
ClientConfig, OAuth2AuthConfig, ApiKeyAuthConfig, EnvConfigLoader, PathConfigLoader)
111 changes: 69 additions & 42 deletions src/vectara/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import Annotated
import json
import yaml
from os import path, sep
from os import path, sep, getenv
from pathlib import Path

"""
Expand Down Expand Up @@ -77,30 +77,19 @@ class ClientConfig(BaseModel):
Discriminator(get_discriminator_value),
]

def loadConfig(config: str) -> ClientConfig:
"""
Loads our configuration from JSON onto our data classes.
:param config: the input configuration in JSON format.
:return: the parsed client configuration1
:raises TypeError: if the configuration cannot be parsed correctly
"""
logger.info(f"Loading config from {config}")
return ClientConfig.model_validate_json(config)
CONFIG_FILE_NAME = ".vec_auth.yaml"
DEFAULT_CONFIG_NAME = "default"

class BaseConfigLoader(ABC):
CONFIG_FILE_NAME = ".vec_auth.yaml"

DEFAULT_CONFIG_NAME = "default"

def __init__(self, profile: Union[str, None] = DEFAULT_CONFIG_NAME):
self.logger = logging.getLogger(self.__class__.__name__)
if not profile:
self.profile = self.DEFAULT_CONFIG_NAME
self.profile = DEFAULT_CONFIG_NAME
else:
self.profile = profile

def load(self):
def load(self) -> ClientConfig:
"""
:return: The client configuration in our domain class ClientConfig
:raises TypeError: Should be implemented in subclasses.
Expand All @@ -120,63 +109,101 @@ def _load_yaml_config(self, final_config_path):
self.logger.info(f"Loading specified profile [{self.profile}]")
profile_to_load = self.profile
else:
self.logger.info(f"Loading default configuration [{BaseConfigLoader.DEFAULT_CONFIG_NAME}]")
profile_to_load = BaseConfigLoader.DEFAULT_CONFIG_NAME
self.logger.info(f"Loading default configuration [{DEFAULT_CONFIG_NAME}]")
profile_to_load = DEFAULT_CONFIG_NAME


if profile_to_load in creds:
return creds[profile_to_load]
else:
raise TypeError(f"Specified profile [{profile_to_load}] not found in [{final_config_path}]")

def _save(self, client_config: ClientConfig, final_config_path: str):
def _save(self, client_config: ClientConfig, final_config_path: Path):
creds = self._load_yaml_full(final_config_path)
creds[self.profile] = client_config.model_dump()

with open(final_config_path, 'w') as yaml_stream:
yaml.safe_dump(creds, yaml_stream)

def _delete(self, final_config_path: str):
def _delete(self, final_config_path: Path):
creds = self._load_yaml_full(final_config_path)
del creds[self.profile]

with open(final_config_path, 'w') as yaml_stream:
yaml.safe_dump(creds, yaml_stream)


class JsonConfigLoader(BaseConfigLoader):
class EnvConfigLoader:

"""
Loads our configuration from JSON
The environment variable containing the customer id
"""
ENV_CUSTOMER_ID = "VECTARA_CUSTOMER_ID"

def __init__(self, config_json: str, profile: Union[str, None] = BaseConfigLoader.DEFAULT_CONFIG_NAME):
super().__init__(profile=profile)
self.config_json = config_json
"""
The environment variable containing the API key
"""
ENV_API_KEY = "VECTARA_API_KEY"

def load(self):
self.logger.info("Loading configuration from JSON string")
config_dict = json.loads(self.config_json)
return ClientConfig.model_validate(config_dict)
"""
The environment variable containing the OAuth2 Client ID
"""
ENV_OAUTH2_CLIENT_ID = "VECTARA_CLIENT_ID"

"""
The environment variable containing the OAuth2 Client Secret
"""
ENV_OAUTH2_CLIENT_SECRET = "VECTARA_CLIENT_SECRET"

def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)

def load(self) -> Optional[ClientConfig]:
api_key = getenv(self.ENV_API_KEY)
oauth2_client_id = getenv(self.ENV_OAUTH2_CLIENT_ID)
oauth2_client_secret = getenv(self.ENV_OAUTH2_CLIENT_SECRET)
customer_id = getenv(self.ENV_CUSTOMER_ID)

if api_key or (oauth2_client_id and oauth2_client_secret):

if not customer_id:
self.logger.warning("No customer ID, using foo value as placeholder since deprecated for V2 API")
customer_id = "foo"
if api_key:
return ClientConfig.model_validate({
"customer_id": customer_id,
"auth": {"api_key": api_key}
})
else:
return ClientConfig.model_validate({
"customer_id": customer_id,
"auth": {"app_client_id": oauth2_client_id, "app_client_secret": oauth2_client_secret}
})
else:
return None


class PathConfigLoader(BaseConfigLoader):
"""
Loads our configuration from the specified folder/file
"""

def __init__(self, config_path : str, profile: Union[str, None] = BaseConfigLoader.DEFAULT_CONFIG_NAME):
def __init__(self, config_path: Union[str, Path], profile: Union[str, None] = DEFAULT_CONFIG_NAME):
super().__init__(profile=profile)

self.config_path = config_path
if isinstance(config_path, Path):
self.config_path = config_path
else:
self.config_path = Path(config_path)

def load(self):
self.logger.info(f"Loading configuration from path {self.config_path}")

if path.exists(self.config_path) and path.isdir(self.config_path):
self.logger.info(f"Configuration param is a path, looking for {BaseConfigLoader.CONFIG_FILE_NAME}")
looking_for = self.config_path / BaseConfigLoader.CONFIG_FILE_NAME
self.logger.info(f"Configuration param is a path, looking for {CONFIG_FILE_NAME}")
looking_for = self.config_path / CONFIG_FILE_NAME
if not path.exists(looking_for) or not path.isfile(looking_for):
raise TypeError(f"Unable to find configuration file [{BaseConfigLoader.CONFIG_FILE_NAME}]"
raise TypeError(f"Unable to find configuration file [{CONFIG_FILE_NAME}]"
f" within specified directory [{self.config_path}]")
elif path.exists(self.config_path) and path.isfile(self.config_path):
self.logger.info(f"Configuration param is a file")
Expand All @@ -188,23 +215,23 @@ def load(self):
return ClientConfig.model_validate(config_dict)



class HomeConfigLoader(BaseConfigLoader):
"""
Loads our configuration from the users home directory
"""

def __init__(self, profile: Union[str, None] = BaseConfigLoader.DEFAULT_CONFIG_NAME):
def __init__(self, profile: Union[str, None] = DEFAULT_CONFIG_NAME):
super().__init__(profile=profile)

def _build_config_path(self) -> str:
home = str(Path.home())
self.logger.info(f"Loading configuration from users home directory [{home}]")
def _build_config_path(self) -> Path:
home_path = Path.home()
home_str = str(home_path.resolve())
self.logger.info(f"Loading configuration from users home directory [{home_str}]")

looking_for = home + sep + BaseConfigLoader.CONFIG_FILE_NAME
looking_for = home_path / CONFIG_FILE_NAME
if not path.exists(looking_for) or not path.isfile(looking_for):
raise TypeError(f"Unable to find configuration file [{BaseConfigLoader.CONFIG_FILE_NAME}]"
f" within home directory [{home}]")
raise TypeError(f"Unable to find configuration file [{CONFIG_FILE_NAME}]"
f" within home directory [{home_str}]")
return looking_for

def load(self):
Expand Down
Loading

0 comments on commit 5610858

Please sign in to comment.