-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactored client and added missing config methods (env) (#23)
- Loading branch information
1 parent
c25b2d6
commit 5610858
Showing
7 changed files
with
244 additions
and
146 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import os | ||
|
||
from vectara.factory import Factory | ||
from vectara.config import HomeConfigLoader, EnvConfigLoader, ApiKeyAuthConfig, OAuth2AuthConfig | ||
from pathlib import Path | ||
|
||
import unittest | ||
import logging | ||
|
||
class FactoryConfigTest(unittest.TestCase): | ||
""" | ||
This test depends on our YAML default config being defined. | ||
We use this to test various methods of injection. | ||
""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
logging.basicConfig(format='%(asctime)s:%(name)-35s %(levelname)s:%(message)s', level=logging.INFO, | ||
datefmt='%H:%M:%S %z') | ||
self.logger = logging.getLogger(self.__class__.__name__) | ||
|
||
def _test_factory_auth(self, target: Factory, expected_method: str): | ||
client = target.build() | ||
self.assertEqual(expected_method, target.load_method) | ||
|
||
if not client.corpus_manager: | ||
raise Exception("Corpus manager should be defined") | ||
|
||
results = client.corpus_manager.find_corpora_with_filter("", 1) | ||
if results and len(results) > 0: | ||
self.logger.info(f"Found corpus [{results[0].key}]") | ||
|
||
|
||
def test_default_load(self): | ||
factory = Factory() | ||
self._test_factory_auth(factory, "path_home") | ||
|
||
def test_explicit_path(self): | ||
factory = Factory(config_path=str(Path.home().resolve())) | ||
self._test_factory_auth(factory, "path_explicit") | ||
|
||
def test_env(self): | ||
client_config = HomeConfigLoader().load() | ||
os.environ[EnvConfigLoader.ENV_CUSTOMER_ID] = client_config.customer_id | ||
if isinstance(client_config.auth, ApiKeyAuthConfig): | ||
os.environ[EnvConfigLoader.ENV_API_KEY] = client_config.auth.api_key | ||
elif isinstance(client_config.auth, OAuth2AuthConfig): | ||
os.environ[EnvConfigLoader.ENV_OAUTH2_CLIENT_ID] = client_config.auth.app_client_id | ||
os.environ[EnvConfigLoader.ENV_OAUTH2_CLIENT_SECRET] = client_config.auth.app_client_secret | ||
|
||
try: | ||
factory = Factory() | ||
self._test_factory_auth(factory, "env") | ||
finally: | ||
if isinstance(client_config.auth, ApiKeyAuthConfig): | ||
del os.environ[EnvConfigLoader.ENV_API_KEY] | ||
elif isinstance(client_config.auth, OAuth2AuthConfig): | ||
del os.environ[EnvConfigLoader.ENV_OAUTH2_CLIENT_ID] | ||
del os.environ[EnvConfigLoader.ENV_OAUTH2_CLIENT_SECRET] | ||
|
||
def test_explicit_typed(self): | ||
client_config = HomeConfigLoader().load() | ||
factory = Factory(config=client_config) | ||
self._test_factory_auth(factory, "explicit_typed") | ||
|
||
def test_explicit_dict(self): | ||
client_config = HomeConfigLoader().load().model_dump() | ||
factory = Factory(config=client_config) | ||
self._test_factory_auth(factory, "explicit_dict") | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,37 @@ | ||
from .base_client import BaseVectara, AsyncBaseVectara | ||
from vectara.managers.corpus import CorpusManager | ||
from vectara.managers.upload import UploadManager, UploadWrapper | ||
from vectara.utils import LabHelper | ||
|
||
class Vectara(BaseVectara): | ||
pass | ||
from typing import Union, Optional, Callable | ||
import logging | ||
|
||
class Vectara(BaseVectara): | ||
""" | ||
We extend the Vectara client, adding additional helper services. Due to the methodology used in the | ||
Vectara constructor, hard-coding dependencies and using an internal wrapper, we use lazy initialization | ||
for the helper classes like the CorpusManager. | ||
TODO Change Client to build dependencies inside constructor (harder to decouple, but removes optionality) | ||
""" | ||
|
||
def __init__(self, *args, | ||
**kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.logger = logging.getLogger(self.__class__.__name__) | ||
self.corpus_manager: Union[None, CorpusManager] = None | ||
self.upload_manager: Union[None, UploadManager] = None | ||
self.lab_helper: Union[None, LabHelper] = None | ||
|
||
def set_corpus_manager(self, corpus_manager: CorpusManager) -> None: | ||
self.corpus_manager = corpus_manager | ||
|
||
def set_upload_manager(self, upload_manager: UploadManager) -> None: | ||
self.upload_manager = upload_manager | ||
|
||
def set_lab_helper(self, lab_helper: LabHelper) -> None: | ||
self.lab_helper = lab_helper | ||
|
||
class AsyncVectara(AsyncBaseVectara): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .config import (BaseConfigLoader, HomeConfigLoader, JsonConfigLoader, BaseAuthConfig, | ||
ClientConfig, OAuth2AuthConfig, ApiKeyAuthConfig) | ||
from .config import (BaseConfigLoader, HomeConfigLoader, BaseAuthConfig, | ||
ClientConfig, OAuth2AuthConfig, ApiKeyAuthConfig, EnvConfigLoader, PathConfigLoader) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.