class ContainsCheck(BaseCheck):
    """Check that passes when an expected snippet occurs in the response."""

    KEY = "contains"

    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        """Report whether the expected snippet is present in the response.

        Args:
            parsed_response: Response text to search in.
            test_data: Test case mapping; the entry under ``KEY`` holds the
                snippet, which is stripped of surrounding whitespace before
                the substring test.

        Returns:
            True if the stripped snippet is found in ``parsed_response``,
            False otherwise.
        """
        logger.debug(f"Response: {parsed_response}")
        needle = test_data[ContainsCheck.KEY]
        logger.debug(f"Expected Response to contain: {needle}")
        return needle.strip() in parsed_response


class DoesNotContainCheck(BaseCheck):
    """Check that passes when a forbidden snippet is absent from the response."""

    KEY = "does_not_contain"

    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        """Report whether the forbidden snippet is absent from the response.

        Args:
            parsed_response: Response text to search in.
            test_data: Test case mapping; the entry under ``KEY`` holds the
                snippet, which is stripped of surrounding whitespace before
                the substring test.

        Returns:
            False if the stripped snippet is found in ``parsed_response``,
            True otherwise.
        """
        logger.debug(f"Response: {parsed_response}")
        forbidden = test_data[DoesNotContainCheck.KEY]
        logger.debug(f"Expected Response to not contain: '{forbidden}'")
        return forbidden.strip() not in parsed_response
def call_directly(
    url: str, headers: dict, data: dict, timeout: Optional[float] = None
) -> Optional[requests.Response]:
    """POST ``data`` as JSON straight to ``url``, bypassing CodeGate.

    Args:
        url: Endpoint to send the request to.
        headers: Request headers. The mapping is copied before
            ``Content-Type: application/json`` is added, so the caller's
            dict is left untouched.
        data: JSON-serialisable request body. Its ``"stream"`` entry
            (default ``False``) decides whether the response is streamed.
        timeout: Seconds handed to ``requests.post``. ``None`` (the default)
            waits indefinitely, which matches the previous behaviour.

    Returns:
        The response when the request succeeded with a non-error status;
        ``None`` when anything raised (the error is logged).
    """
    try:
        # Build a new dict instead of assigning into ``headers``: the same
        # mapping is also handed to other request helpers by the caller.
        request_headers = {**headers, "Content-Type": "application/json"}
        stream = data.get("stream", False)
        response = requests.post(
            url,
            headers=request_headers,
            json=data,
            stream=stream,
            timeout=timeout,
        )
        # Turn 4xx/5xx answers into exceptions so they are reported below.
        response.raise_for_status()
        return response
    except Exception as e:
        # Best-effort helper: callers treat ``None`` as "no direct response".
        logger.error(f"Error making direct request to {url}: {str(e)}")
        return None
No direct response received") + return False + # Debug response info logger.debug(f"Response status: {response.status_code}") logger.debug(f"Response headers: {dict(response.headers)}") @@ -152,13 +174,24 @@ async def run_test(self, test: dict, test_headers: dict) -> bool: parsed_response = self.parse_response_message(response, streaming=streaming) logger.debug(f"Response message: {parsed_response}") + if direct_response: + # Dirty hack to pass direct response to checks + test["direct_response"] = self.parse_response_message( + direct_response, streaming=streaming + ) + logger.debug(f"Direct response message: {test['direct_response']}") + # Load appropriate checks for this test checks = CheckLoader.load(test) # Run all checks all_passed = True for check in checks: + logger.info(f"Running check: {check.__class__.__name__}") passed_check = await check.run_check(parsed_response, test) + logger.info( + f"Check {check.__class__.__name__} {'passed' if passed_check else 'failed'}" + ) if not passed_check: all_passed = False diff --git a/tests/integration/ollama/testcases.yaml b/tests/integration/ollama/testcases.yaml index 9931ecdb..56a13b57 100644 --- a/tests/integration/ollama/testcases.yaml +++ b/tests/integration/ollama/testcases.yaml @@ -31,6 +31,9 @@ testcases: name: Ollama Chat provider: ollama url: http://127.0.0.1:8989/ollama/chat/completions + codegate_enrichment: + provider_url: http://127.0.0.1:11434/api/chat + expect_difference: false data: | { "max_tokens":4096, @@ -55,6 +58,9 @@ testcases: name: Ollama FIM provider: ollama url: http://127.0.0.1:8989/ollama/api/generate + codegate_enrichment: + provider_url: http://127.0.0.1:11434/api/generate + expect_difference: false data: | { "stream": true, @@ -88,6 +94,9 @@ testcases: name: Ollama Malicious Package provider: ollama url: http://127.0.0.1:8989/ollama/chat/completions + codegate_enrichment: + provider_url: http://127.0.0.1:11434/api/chat + expect_difference: true data: | { "max_tokens":4096, @@ -112,6 
+121,9 @@ testcases: name: Ollama secret redacting chat provider: ollama url: http://127.0.0.1:8989/ollama/chat/completions + codegate_enrichment: + provider_url: http://127.0.0.1:11434/api/chat + expect_difference: true data: | { "max_tokens":4096, diff --git a/tests/integration/vllm/testcases.yaml b/tests/integration/vllm/testcases.yaml index bb446ced..52df9598 100644 --- a/tests/integration/vllm/testcases.yaml +++ b/tests/integration/vllm/testcases.yaml @@ -31,6 +31,9 @@ testcases: name: VLLM Chat provider: vllm url: http://127.0.0.1:8989/vllm/chat/completions + codegate_enrichment: + provider_url: http://127.0.0.1:8000/v1/chat/completions + expect_difference: false data: | { "max_tokens":4096, @@ -55,6 +58,10 @@ testcases: name: VLLM FIM provider: vllm url: http://127.0.0.1:8989/vllm/completions +# This is commented out for now as there's some issue with parsing the streamed response from the model (on the vllm side, not codegate) +# codegate_enrichment: +# provider_url: http://127.0.0.1:8000/v1/completions +# expect_difference: false data: | { "model": "Qwen/Qwen2.5-Coder-0.5B-Instruct", @@ -84,6 +91,9 @@ testcases: name: VLLM Malicious Package provider: vllm url: http://127.0.0.1:8989/vllm/chat/completions + codegate_enrichment: + provider_url: http://127.0.0.1:8000/v1/chat/completions + expect_difference: true data: | { "max_tokens":4096,