Skip to content
16 changes: 11 additions & 5 deletions codebase_rag/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def format_missing_api_key_errors(
return error_msg


LOCAL_PROVIDERS = frozenset({cs.Provider.OLLAMA, cs.Provider.LOCAL, cs.Provider.VLLM})


@dataclass
class ModelConfig:
provider: str
Expand All @@ -102,7 +105,7 @@ class ModelConfig:
endpoint: str | None = None
project_id: str | None = None
region: str | None = None
provider_type: str | None = None
provider_type: cs.GoogleProviderType | None = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The ModelConfig dataclass is intended to be a generic container for various providers. Typing provider_type strictly as cs.GoogleProviderType makes this field Google-specific. If other providers (e.g., AWS Bedrock) were to use this field in the future, it would require updating this enum. Consider keeping this as str | None for flexibility, while still using the enum for comparison in the validation logic.

Suggested change
provider_type: cs.GoogleProviderType | None = None
provider_type: str | None = None

thinking_budget: int | None = None
service_account_file: str | None = None

Expand All @@ -113,8 +116,11 @@ def to_update_kwargs(self) -> ModelConfigKwargs:
return ModelConfigKwargs(**result)

def validate_api_key(self, role: str = cs.DEFAULT_MODEL_ROLE) -> None:
local_providers = {cs.Provider.OLLAMA, cs.Provider.LOCAL, cs.Provider.VLLM}
if self.provider.lower() in local_providers:
provider_lower = self.provider.lower()
if provider_lower in LOCAL_PROVIDERS or (
provider_lower == cs.Provider.GOOGLE
and self.provider_type == cs.GoogleProviderType.VERTEX
):
return
if (
not self.api_key
Expand Down Expand Up @@ -150,7 +156,7 @@ class AppConfig(BaseSettings):
ORCHESTRATOR_ENDPOINT: str | None = None
ORCHESTRATOR_PROJECT_ID: str | None = None
ORCHESTRATOR_REGION: str = cs.DEFAULT_REGION
ORCHESTRATOR_PROVIDER_TYPE: str | None = None
ORCHESTRATOR_PROVIDER_TYPE: cs.GoogleProviderType | None = None
ORCHESTRATOR_THINKING_BUDGET: int | None = None
ORCHESTRATOR_SERVICE_ACCOUNT_FILE: str | None = None

Expand All @@ -160,7 +166,7 @@ class AppConfig(BaseSettings):
CYPHER_ENDPOINT: str | None = None
CYPHER_PROJECT_ID: str | None = None
CYPHER_REGION: str = cs.DEFAULT_REGION
CYPHER_PROVIDER_TYPE: str | None = None
CYPHER_PROVIDER_TYPE: cs.GoogleProviderType | None = None
CYPHER_THINKING_BUDGET: int | None = None
CYPHER_SERVICE_ACCOUNT_FILE: str | None = None

Expand Down
62 changes: 58 additions & 4 deletions codebase_rag/tests/test_github_issues_integration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
from unittest.mock import patch

import pytest

from codebase_rag.config import AppConfig
from codebase_rag.constants import GoogleProviderType


class TestGitHubIssuesIntegration:
Expand Down Expand Up @@ -142,9 +145,6 @@ def test_openai_compatible_endpoints(self) -> None:
assert orchestrator.endpoint == "https://api.together.xyz/v1"

def test_vertex_ai_enterprise_scenario(self) -> None:
"""
Test enterprise Vertex AI configuration scenario.
"""
env_content = {
"ORCHESTRATOR_PROVIDER": "google",
"ORCHESTRATOR_MODEL": "gemini-2.5-pro",
Expand All @@ -162,9 +162,63 @@ def test_vertex_ai_enterprise_scenario(self) -> None:
assert orchestrator.model_id == "gemini-2.5-pro"
assert orchestrator.project_id == "my-enterprise-project"
assert orchestrator.region == "us-central1"
assert orchestrator.provider_type == "vertex"
assert orchestrator.provider_type == GoogleProviderType.VERTEX
assert orchestrator.service_account_file == "/path/to/service-account.json"

def test_vertex_ai_skips_api_key_validation(self) -> None:
env_content = {
"ORCHESTRATOR_PROVIDER": "google",
"ORCHESTRATOR_MODEL": "gemini-2.5-pro",
"ORCHESTRATOR_PROJECT_ID": "my-project",
"ORCHESTRATOR_REGION": "us-central1",
"ORCHESTRATOR_PROVIDER_TYPE": "vertex",
"ORCHESTRATOR_SERVICE_ACCOUNT_FILE": "/path/to/sa.json",
"CYPHER_PROVIDER": "google",
"CYPHER_MODEL": "gemini-2.5-flash",
"CYPHER_PROJECT_ID": "my-project",
"CYPHER_REGION": "us-central1",
"CYPHER_PROVIDER_TYPE": "vertex",
"CYPHER_SERVICE_ACCOUNT_FILE": "/path/to/sa.json",
}

with patch.dict(os.environ, env_content):
config = AppConfig()

orchestrator = config.active_orchestrator_config
orchestrator.validate_api_key("orchestrator")

cypher = config.active_cypher_config
cypher.validate_api_key("cypher")

def test_vertex_ai_with_google_api_key_env_does_not_error(self) -> None:
env_content = {
"ORCHESTRATOR_PROVIDER": "google",
"ORCHESTRATOR_MODEL": "gemini-2.5-pro",
"ORCHESTRATOR_PROJECT_ID": "my-project",
"ORCHESTRATOR_PROVIDER_TYPE": "vertex",
"ORCHESTRATOR_SERVICE_ACCOUNT_FILE": "/path/to/sa.json",
"GOOGLE_API_KEY": "stray-key-from-env",
}

with patch.dict(os.environ, env_content):
config = AppConfig()
orchestrator = config.active_orchestrator_config
orchestrator.validate_api_key("orchestrator")

def test_google_gla_without_api_key_raises(self) -> None:
env_content = {
"ORCHESTRATOR_PROVIDER": "google",
"ORCHESTRATOR_MODEL": "gemini-2.5-pro",
"ORCHESTRATOR_PROVIDER_TYPE": "gla",
"ORCHESTRATOR_API_KEY": "",
}

with patch.dict(os.environ, env_content):
config = AppConfig()
orchestrator = config.active_orchestrator_config
with pytest.raises(ValueError, match="API Key Missing"):
orchestrator.validate_api_key("orchestrator")

def test_reasoning_model_thinking_budget(self) -> None:
"""
Test configuration for reasoning models with thinking budget.
Expand Down
9 changes: 7 additions & 2 deletions codebase_rag/types_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

from prompt_toolkit.styles import Style

from .constants import NodeLabel, RelationshipType, SupportedLanguage
from .constants import (
GoogleProviderType,
NodeLabel,
RelationshipType,
SupportedLanguage,
)

if TYPE_CHECKING:
from tree_sitter import Language, Node, Parser, Query
Expand Down Expand Up @@ -148,7 +153,7 @@ class ModelConfigKwargs(TypedDict, total=False):
endpoint: str | None
project_id: str | None
region: str | None
provider_type: str | None
provider_type: GoogleProviderType | None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the change in ModelConfig, keeping provider_type as str | None in ModelConfigKwargs ensures that the type remains generic and extensible for other providers.

Suggested change
provider_type: GoogleProviderType | None
provider_type: str | None

thinking_budget: int | None
service_account_file: str | None

Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.