171 changes: 91 additions & 80 deletions src/connectors/openai.py
@@ -95,13 +95,14 @@ def _resolve_translation_service() -> TranslationService:
)
return service

def get_headers(self) -> dict[str, str]:
def get_headers(self, api_key: str | None = None) -> dict[str, str]:
"""Return request headers including API key and per-request identity."""

headers: dict[str, str] = {}

if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
effective_api_key = api_key if api_key is not None else self.api_key
if effective_api_key:
headers["Authorization"] = f"Bearer {effective_api_key}"

if self.identity:
try:
@@ -143,7 +144,7 @@ async def initialize(self, **kwargs: Any) -> None:
logger.warning("Failed to fetch models: %s", e, exc_info=True)
# Log the error but don't fail initialization

async def _perform_health_check(self) -> bool:
async def _perform_health_check(self, api_key: str | None = None) -> bool:
"""Perform a health check by testing API connectivity.

This method tests actual API connectivity by making a simple request to verify
@@ -154,11 +155,12 @@ async def _perform_health_check(self) -> bool:
"""
try:
# Test API connectivity with a simple models endpoint request
if not self.api_key:
effective_api_key = api_key if api_key is not None else self.api_key
if not effective_api_key:
logger.warning("Health check failed - no API key available")
return False

headers = self.get_headers()
headers = self.get_headers(effective_api_key)
if not headers.get("Authorization"):
logger.warning("Health check failed - no authorization header")
return False
@@ -183,7 +185,7 @@ async def _perform_health_check(self) -> bool:
)
return False

async def _ensure_healthy(self) -> None:
async def _ensure_healthy(self, api_key: str | None = None) -> None:
"""Ensure the backend is healthy before use.

This method performs health checks on first use, similar to how
@@ -198,7 +200,7 @@ async def _ensure_healthy(self) -> None:
f"Performing first-use health check for {self.backend_type} backend"
)

healthy = await self._perform_health_check()
healthy = await self._perform_health_check(api_key)
if not healthy:
logger.warning(
"Health check did not pass; continuing with lazy verification on first request"
@@ -225,88 +227,97 @@ async def chat_completions(
processed_messages: list[Any],
effective_model: str,
identity: IAppIdentityConfig | None = None,
api_key: str | None = None,
**kwargs: Any,
) -> ResponseEnvelope | StreamingResponseEnvelope:
# Perform health check if enabled (for subclasses that support it)
await self._ensure_healthy()

# request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest)
# (the frontend controller converts from frontend-specific format to domain format)
# Backends should ONLY convert FROM domain TO backend-specific format
# Type assertion: we know from architectural design that request_data is ChatRequest-like
from typing import cast
original_identity = self.identity
⚠️ Potential issue | 🔴 Critical

Critical: identity mutation reintroduces the same race condition that the api_key handling previously had.

While the per-call api_key handling correctly avoids mutating instance state, the identity handling still uses the save/mutate/restore pattern that previous reviewers flagged as unsafe:

  • Line 233: original_identity = self.identity
  • Line 259: self.identity = identity
  • Line 289: self.identity = identity (redundant)
  • Line 320: self.identity = original_identity (restore in finally)

If this connector instance services concurrent requests with different identities, the same race condition occurs: one request can overwrite another's identity while both are in flight, and get_headers() (called at lines 275, 290) will read the wrong identity, producing incorrect headers for one or both requests.

Consider passing identity as a parameter to get_headers() and downstream methods rather than mutating self.identity, matching the pattern you correctly used for api_key.

-    def get_headers(self, api_key: str | None = None) -> dict[str, str]:
+    def get_headers(self, api_key: str | None = None, identity: IAppIdentityConfig | None = None) -> dict[str, str]:
         """Return request headers including API key and per-request identity."""
         headers: dict[str, str] = {}
         effective_api_key = api_key if api_key is not None else self.api_key
         if effective_api_key:
             headers["Authorization"] = f"Bearer {effective_api_key}"
-        if self.identity:
+        effective_identity = identity if identity is not None else self.identity
+        if effective_identity:
             try:
-                identity_headers = self.identity.get_resolved_headers(None)
+                identity_headers = effective_identity.get_resolved_headers(None)
             except Exception:
                 identity_headers = {}
             if identity_headers:
                 headers.update(identity_headers)
         return ensure_loop_guard_header(headers)

Then update call sites to pass identity explicitly:

-                    base_headers = self.get_headers(effective_api_key)
+                    base_headers = self.get_headers(effective_api_key, identity)

Also applies to: 259-259, 289-289, 320-320

🤖 Prompt for AI Agents
In src/connectors/openai.py around lines 233-320, do not mutate self.identity
(originally saved at line 233 and modified at lines 259, 289, restored at 320)
because it creates a race condition for concurrent requests; instead add an
identity parameter to get_headers() and any downstream helper methods, remove
all assignments to self.identity (including the redundant line at 289 and the
restore in finally at 320), and update all call sites within this block to pass
the local identity variable explicitly so each request uses its own identity
without touching instance state.
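
To make the hazard concrete, here is a minimal, self-contained sketch (a hypothetical SketchConnector, not the project's real connector or identity classes) of why the save/mutate/restore pattern on a shared attribute lets concurrent requests observe the wrong identity, while passing the identity as a parameter keeps each request isolated:

import asyncio


class SketchConnector:
    """Hypothetical stand-in for a connector that builds per-request headers."""

    def __init__(self) -> None:
        self.identity: str | None = None

    async def chat_mutating(self, identity: str) -> str:
        # Save/mutate/restore on shared instance state, as in the flagged code.
        original = self.identity
        self.identity = identity
        try:
            await asyncio.sleep(0)  # yield control, as a real HTTP call would
            return f"headers for {self.identity}"
        finally:
            self.identity = original

    async def chat_param(self, identity: str) -> str:
        # Per-request identity passed as a parameter; no shared state touched.
        await asyncio.sleep(0)
        return f"headers for {identity}"


async def main() -> None:
    connector = SketchConnector()
    print(await asyncio.gather(
        connector.chat_mutating("tenant-a"),
        connector.chat_mutating("tenant-b"),
    ))
    print(await asyncio.gather(
        connector.chat_param("tenant-a"),
        connector.chat_param("tenant-b"),
    ))


asyncio.run(main())

In a typical run, the mutating variant has the first request build headers against the second request's identity and the second against the restored original, while the parameter-passing variant always reports each request's own identity.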


from src.core.domain.chat import CanonicalChatRequest, ChatRequest
# Allow callers to supply a one-off API key (e.g., multi-tenant flows).
effective_api_key = api_key if api_key is not None else self.api_key

if not isinstance(request_data, ChatRequest):
raise TypeError(
f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. "
"Backend connectors should only receive domain-format requests."
)
# Cast to CanonicalChatRequest for mypy compatibility with _prepare_payload signature
domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data)

# Ensure identity headers are scoped to the current request only.
self.identity = identity
# Perform health check if enabled (for subclasses that support it)
try:
await self._ensure_healthy(effective_api_key)

# Prepare the payload using a helper so subclasses and tests can
# override or patch payload construction logic easily.
payload = await self._prepare_payload(
domain_request, processed_messages, effective_model
)
headers_override = kwargs.pop("headers_override", None)
headers: dict[str, str] | None = None
# request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest)
# (the frontend controller converts from frontend-specific format to domain format)
# Backends should ONLY convert FROM domain TO backend-specific format
# Type assertion: we know from architectural design that request_data is ChatRequest-like
from typing import cast

if headers_override is not None:
# Avoid mutating the caller-provided mapping while preserving any
# Authorization header we compute from the configured API key.
headers = dict(headers_override)
from src.core.domain.chat import CanonicalChatRequest, ChatRequest

try:
base_headers = self.get_headers()
except Exception:
base_headers = None

if base_headers:
merged_headers = dict(base_headers)
merged_headers.update(headers)
headers = merged_headers
else:
try:
# Always update the cached identity so that per-request
# identity headers do not leak between calls. Downstream
# callers rely on identity-specific headers being scoped to
# a single request.
self.identity = identity
headers = self.get_headers()
except Exception:
headers = None
if not isinstance(request_data, ChatRequest):
raise TypeError(
f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. "
"Backend connectors should only receive domain-format requests."
)
# Cast to CanonicalChatRequest for mypy compatibility with _prepare_payload signature
domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data)

api_base = kwargs.get("openai_url") or self.api_base_url
url = f"{api_base.rstrip('/')}/chat/completions"
# Ensure identity headers are scoped to the current request only.
self.identity = identity

if domain_request.stream:
# Return a domain-level streaming envelope (raw bytes iterator)
try:
content_iterator = await self._handle_streaming_response(
url,
payload,
headers,
domain_request.session_id or "",
"openai",
)
except AuthenticationError as e:
raise HTTPException(status_code=401, detail=str(e))
return StreamingResponseEnvelope(
content=content_iterator,
media_type="text/event-stream",
headers={},
)
else:
# Return a domain ResponseEnvelope for non-streaming
return await self._handle_non_streaming_response(
url, payload, headers, domain_request.session_id or ""
# Prepare the payload using a helper so subclasses and tests can
# override or patch payload construction logic easily.
payload = await self._prepare_payload(
domain_request, processed_messages, effective_model
)
headers_override = kwargs.pop("headers_override", None)
headers: dict[str, str] | None = None

if headers_override is not None:
# Avoid mutating the caller-provided mapping while preserving any
# Authorization header we compute from the configured API key.
headers = dict(headers_override)

try:
base_headers = self.get_headers(effective_api_key)
except Exception:
base_headers = None
Comment on lines +274 to +277
⚠️ Potential issue | 🟠 Major

Catch specific exceptions and log errors.

These blocks catch broad Exception without logging, making debugging difficult when get_headers() fails. Per coding guidelines: "If a broad exception must be caught, log with exc_info=True and re-raise a specific custom exception."

Apply this diff to add logging:

             if headers_override is not None:
                 headers = dict(headers_override)
                 try:
                     base_headers = self.get_headers(effective_api_key)
-                except Exception:
+                except Exception as e:
+                    logger.warning("Failed to get base headers: %s", e, exc_info=True)
                     base_headers = None
                 if base_headers:
                     merged_headers = dict(base_headers)
                     merged_headers.update(headers)
                     headers = merged_headers
             else:
                 try:
-                    # Always update the cached identity so that per-request
-                    # identity headers do not leak between calls. Downstream
-                    # callers rely on identity-specific headers being scoped to
-                    # a single request.
-                    self.identity = identity
                     headers = self.get_headers(effective_api_key)
-                except Exception:
+                except Exception as e:
+                    logger.warning("Failed to get headers: %s", e, exc_info=True)
                     headers = None

Based on coding guidelines.

Also applies to: 284-292
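
For reference, a minimal sketch of the full pattern the guideline describes (log with exc_info=True, then re-raise a specific custom exception); HeaderResolutionError and resolve_headers are hypothetical names, and the suggested diff above intentionally stops at logging because the surrounding connector code falls back to headers = None rather than raising:

import logging

logger = logging.getLogger(__name__)


class HeaderResolutionError(RuntimeError):
    """Raised when request headers cannot be constructed."""


def resolve_headers(connector, api_key: str | None) -> dict[str, str]:
    try:
        return connector.get_headers(api_key)
    except Exception as exc:  # broad catch is acceptable only with logging
        logger.warning("Failed to get headers: %s", exc, exc_info=True)
        raise HeaderResolutionError("could not build request headers") from exc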


if base_headers:
merged_headers = dict(base_headers)
merged_headers.update(headers)
headers = merged_headers
else:
try:
# Always update the cached identity so that per-request
# identity headers do not leak between calls. Downstream
# callers rely on identity-specific headers being scoped to
# a single request.
self.identity = identity
headers = self.get_headers(effective_api_key)
except Exception:
headers = None

api_base = kwargs.get("openai_url") or self.api_base_url
url = f"{api_base.rstrip('/')}/chat/completions"

if domain_request.stream:
# Return a domain-level streaming envelope (raw bytes iterator)
try:
content_iterator = await self._handle_streaming_response(
url,
payload,
headers,
domain_request.session_id or "",
"openai",
)
except AuthenticationError as e:
raise HTTPException(status_code=401, detail=str(e))
return StreamingResponseEnvelope(
content=content_iterator,
media_type="text/event-stream",
headers={},
)
else:
# Return a domain ResponseEnvelope for non-streaming
return await self._handle_non_streaming_response(
url, payload, headers, domain_request.session_id or ""
)
finally:
self.identity = original_identity

async def _prepare_payload(
self,
48 changes: 48 additions & 0 deletions tests/unit/connectors/test_precision_payload_mapping.py
@@ -12,6 +12,7 @@
from src.connectors.openrouter import OpenRouterBackend
from src.core.config.app_config import AppConfig
from src.core.domain.chat import ChatMessage, ChatRequest
from src.core.domain.responses import ResponseEnvelope
from src.core.services.translation_service import TranslationService


@@ -45,6 +46,53 @@ async def fake_post(url: str, json: dict, headers: dict) -> httpx.Response:
assert captured_payload.get("top_p") == 0.34


@pytest.mark.asyncio
async def test_openai_connector_uses_per_call_api_key(
monkeypatch: pytest.MonkeyPatch,
) -> None:
client = httpx.AsyncClient()
connector = OpenAIConnector(
client, AppConfig(), translation_service=TranslationService()
)
connector.disable_health_check()
connector.api_key = None

observed_headers: list[dict[str, str] | None] = []

async def fake_handle(
self: OpenAIConnector,
url: str,
payload: dict[str, Any],
headers: dict[str, str] | None,
session_id: str,
) -> ResponseEnvelope:
observed_headers.append(headers)
return ResponseEnvelope(content={}, status_code=200, headers={})

monkeypatch.setattr(
OpenAIConnector,
"_handle_non_streaming_response",
fake_handle,
)

request = ChatRequest(model="gpt-4o", messages=_messages(), stream=False)

try:
await connector.chat_completions(
request,
request.messages,
request.model,
api_key="per-call-token",
)
finally:
await client.aclose()

assert observed_headers and observed_headers[0] is not None
assert (
observed_headers[0].get("Authorization") == "Bearer per-call-token"
)


@pytest.mark.asyncio
async def test_openai_payload_uses_processed_messages_with_list_content(
monkeypatch: pytest.MonkeyPatch,