Skip to content

Commit

Permalink
test: add verify provider credentials test
Browse files Browse the repository at this point in the history
  • Loading branch information
taskingaijc authored and DynamesC committed Oct 12, 2024
1 parent 3f648c7 commit 67010b9
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 2 deletions.
7 changes: 7 additions & 0 deletions inference/test/inference_service/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ async def verify_credentials(data: Dict):
return ResponseWrapper(response.status, await response.json())


async def verify_provider_credentials(data: Dict):
async with aiohttp.ClientSession() as session:
request_url = f"{Config.BASE_URL}/verify_provider_credentials"
response = await session.post(request_url, json=data)
return ResponseWrapper(response.status, await response.json())


async def caches():
async with aiohttp.ClientSession() as session:
request_url = f"{Config.BASE_URL}/caches"
Expand Down
45 changes: 43 additions & 2 deletions inference/test/test_validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from test.utils.utils import generate_test_cases_for_validation, generate_wildcard_test_case_for_validation
from test.inference_service.inference import verify_credentials
from test.utils.utils import generate_test_cases_for_validation, generate_test_cases_for_provider_validation, generate_wildcard_test_case_for_validation
from test.inference_service.inference import verify_credentials, verify_provider_credentials
from app.models import provider_credentials
import pytest
import asyncio
Expand Down Expand Up @@ -74,6 +74,47 @@ async def test_wildcard_validation(self, test_data):
assert res.json()["status"] == "success", f"test_validation failed: result={res.json()}"
await asyncio.sleep(1.5)

@pytest.mark.parametrize("test_data", generate_test_cases_for_provider_validation(), ids=lambda d: d["provider_id"])
@pytest.mark.asyncio
@pytest.mark.test_id("inference_016-017")
async def test_provider_validation(self, test_data):
provider_id = test_data["provider_id"]
print("provider_id: ", provider_id)

credentials = {
key: provider_credentials.aes_decrypt(test_data["credentials"][key])
for key in test_data["credentials"].keys()
}

request_data = {"provider_id": provider_id, "credentials": credentials}
try:
res = await asyncio.wait_for(verify_provider_credentials(request_data), timeout=120)
except asyncio.TimeoutError:
pytest.skip("Skipping test due to timeout after 2 minutes.")
assert res.status_code == 200, f"test_validation failed: result={res.json()}"
assert res.json()["status"] == "success", f"test_validation failed: result={res.json()}"
await asyncio.sleep(1)

@pytest.mark.parametrize("test_data", generate_test_cases_for_provider_validation(), ids=lambda d: d["provider_id"])
@pytest.mark.asyncio
@pytest.mark.test_id("inference_016-017")
async def test_provider_validation_with_error_credential(self, test_data):
provider_id = test_data["provider_id"]
print("provider_id: ", provider_id)
if test_data["provider_id"] in ["debug", "custom_host", "openrouter", "replicate", "lm_studio", "ollama", "siliconcloud", "llama_api", "localai"]:
pytest.skip("Test not applicable for this provider")
credentials = {key: "12345678" for key in test_data["credentials"].keys()}

request_data = {"provider_id": provider_id, "credentials": credentials}
try:
res = await asyncio.wait_for(verify_provider_credentials(request_data), timeout=120)
except asyncio.TimeoutError:
pytest.skip("Skipping test due to timeout after 2 minutes.")
assert res.status_code == 400, f"test_validation failed: result={res.json()}"
assert res.json()["status"] == "error"
assert res.json()["error"]["code"] == "PROVIDER_ERROR"
await asyncio.sleep(1)

@pytest.mark.parametrize("test_data", generate_test_cases_for_validation(), ids=lambda d: d["model_schema_id"])
@pytest.mark.asyncio
@pytest.mark.test_id("inference_018")
Expand Down
35 changes: 35 additions & 0 deletions inference/test/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"zhipu",
"baichuan",
"hugging_face",
"hugging_face_inference_endpoint",
"tongyi",
"wenxin",
"moonshot",
Expand Down Expand Up @@ -307,6 +308,40 @@ def generate_test_cases_for_validation():
return cases


def generate_test_cases_for_provider_validation():
providers_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../providers")
provider_ids = [name for name in os.listdir(providers_path) if os.path.isdir(os.path.join(providers_path, name))]
provider_ids = [name for name in provider_ids if not name.startswith("_") and not name.startswith("template")]

cases = []
for provider_id in provider_ids:
if provider_id in white_list_providers:
continue
if provider_id == "debug" and CONFIG.PROD:
continue
print("Adding test cases for provider: ", provider_id)
provider_path = os.path.join(providers_path, provider_id, "resources")

provider_yaml_data = load_yaml(os.path.join(provider_path, "provider.yml"))
if provider_yaml_data["provider_id"] in ["localai", "lm_studio", "ollama"]:
continue
pass_provider_level_credential_check = provider_yaml_data.get("pass_provider_level_credential_check")
# if pass_provider_level_credential_check:
# continue
credentials_schema = provider_yaml_data["credentials_schema"]

credentials = {
key: provider_credentials.aes_encrypt(os.environ.get(key)) for key in credentials_schema["required"]
}
base_test_case = {
"provider_id": provider_yaml_data["provider_id"],
"credentials": credentials,
}
cases.append(base_test_case)

return cases


def generate_wildcard_test_case_for_validation():
# load wildcard_test_cases.yml
file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "wildcard_test_cases.yml")
Expand Down

0 comments on commit 67010b9

Please sign in to comment.