- 
                Notifications
    You must be signed in to change notification settings 
- Fork 1
Fix OpenAI per-call API key handling #628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -95,13 +95,14 @@ def _resolve_translation_service() -> TranslationService: | |
| ) | ||
| return service | ||
|  | ||
| def get_headers(self) -> dict[str, str]: | ||
| def get_headers(self, api_key: str | None = None) -> dict[str, str]: | ||
| """Return request headers including API key and per-request identity.""" | ||
|  | ||
| headers: dict[str, str] = {} | ||
|  | ||
| if self.api_key: | ||
| headers["Authorization"] = f"Bearer {self.api_key}" | ||
| effective_api_key = api_key if api_key is not None else self.api_key | ||
| if effective_api_key: | ||
| headers["Authorization"] = f"Bearer {effective_api_key}" | ||
|  | ||
| if self.identity: | ||
| try: | ||
|  | @@ -143,7 +144,7 @@ async def initialize(self, **kwargs: Any) -> None: | |
| logger.warning("Failed to fetch models: %s", e, exc_info=True) | ||
| # Log the error but don't fail initialization | ||
|  | ||
| async def _perform_health_check(self) -> bool: | ||
| async def _perform_health_check(self, api_key: str | None = None) -> bool: | ||
| """Perform a health check by testing API connectivity. | ||
|  | ||
| This method tests actual API connectivity by making a simple request to verify | ||
|  | @@ -154,11 +155,12 @@ async def _perform_health_check(self) -> bool: | |
| """ | ||
| try: | ||
| # Test API connectivity with a simple models endpoint request | ||
| if not self.api_key: | ||
| effective_api_key = api_key if api_key is not None else self.api_key | ||
| if not effective_api_key: | ||
| logger.warning("Health check failed - no API key available") | ||
| return False | ||
|  | ||
| headers = self.get_headers() | ||
| headers = self.get_headers(effective_api_key) | ||
| if not headers.get("Authorization"): | ||
| logger.warning("Health check failed - no authorization header") | ||
| return False | ||
|  | @@ -183,7 +185,7 @@ async def _perform_health_check(self) -> bool: | |
| ) | ||
| return False | ||
|  | ||
| async def _ensure_healthy(self) -> None: | ||
| async def _ensure_healthy(self, api_key: str | None = None) -> None: | ||
| """Ensure the backend is healthy before use. | ||
|  | ||
| This method performs health checks on first use, similar to how | ||
|  | @@ -198,7 +200,7 @@ async def _ensure_healthy(self) -> None: | |
| f"Performing first-use health check for {self.backend_type} backend" | ||
| ) | ||
|  | ||
| healthy = await self._perform_health_check() | ||
| healthy = await self._perform_health_check(api_key) | ||
| if not healthy: | ||
| logger.warning( | ||
| "Health check did not pass; continuing with lazy verification on first request" | ||
|  | @@ -225,88 +227,97 @@ async def chat_completions( | |
| processed_messages: list[Any], | ||
| effective_model: str, | ||
| identity: IAppIdentityConfig | None = None, | ||
| api_key: str | None = None, | ||
| **kwargs: Any, | ||
| ) -> ResponseEnvelope | StreamingResponseEnvelope: | ||
| # Perform health check if enabled (for subclasses that support it) | ||
| await self._ensure_healthy() | ||
|  | ||
| # request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest) | ||
| # (the frontend controller converts from frontend-specific format to domain format) | ||
| # Backends should ONLY convert FROM domain TO backend-specific format | ||
| # Type assertion: we know from architectural design that request_data is ChatRequest-like | ||
| from typing import cast | ||
| original_identity = self.identity | ||
|  | ||
| from src.core.domain.chat import CanonicalChatRequest, ChatRequest | ||
| # Allow callers to supply a one-off API key (e.g., multi-tenant flows). | ||
| effective_api_key = api_key if api_key is not None else self.api_key | ||
|  | ||
| if not isinstance(request_data, ChatRequest): | ||
| raise TypeError( | ||
| f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. " | ||
| "Backend connectors should only receive domain-format requests." | ||
| ) | ||
| # Cast to CanonicalChatRequest for mypy compatibility with _prepare_payload signature | ||
| domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data) | ||
|  | ||
| # Ensure identity headers are scoped to the current request only. | ||
| self.identity = identity | ||
| # Perform health check if enabled (for subclasses that support it) | ||
| try: | ||
| await self._ensure_healthy(effective_api_key) | ||
|  | ||
| # Prepare the payload using a helper so subclasses and tests can | ||
| # override or patch payload construction logic easily. | ||
| payload = await self._prepare_payload( | ||
| domain_request, processed_messages, effective_model | ||
| ) | ||
| headers_override = kwargs.pop("headers_override", None) | ||
| headers: dict[str, str] | None = None | ||
| # request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest) | ||
| # (the frontend controller converts from frontend-specific format to domain format) | ||
| # Backends should ONLY convert FROM domain TO backend-specific format | ||
| # Type assertion: we know from architectural design that request_data is ChatRequest-like | ||
| from typing import cast | ||
|  | ||
| if headers_override is not None: | ||
| # Avoid mutating the caller-provided mapping while preserving any | ||
| # Authorization header we compute from the configured API key. | ||
| headers = dict(headers_override) | ||
| from src.core.domain.chat import CanonicalChatRequest, ChatRequest | ||
|  | ||
| try: | ||
| base_headers = self.get_headers() | ||
| except Exception: | ||
| base_headers = None | ||
|  | ||
| if base_headers: | ||
| merged_headers = dict(base_headers) | ||
| merged_headers.update(headers) | ||
| headers = merged_headers | ||
| else: | ||
| try: | ||
| # Always update the cached identity so that per-request | ||
| # identity headers do not leak between calls. Downstream | ||
| # callers rely on identity-specific headers being scoped to | ||
| # a single request. | ||
| self.identity = identity | ||
| headers = self.get_headers() | ||
| except Exception: | ||
| headers = None | ||
| if not isinstance(request_data, ChatRequest): | ||
| raise TypeError( | ||
| f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. " | ||
| "Backend connectors should only receive domain-format requests." | ||
| ) | ||
| # Cast to CanonicalChatRequest for mypy compatibility with _prepare_payload signature | ||
| domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data) | ||
|  | ||
| api_base = kwargs.get("openai_url") or self.api_base_url | ||
| url = f"{api_base.rstrip('/')}/chat/completions" | ||
| # Ensure identity headers are scoped to the current request only. | ||
| self.identity = identity | ||
|  | ||
| if domain_request.stream: | ||
| # Return a domain-level streaming envelope (raw bytes iterator) | ||
| try: | ||
| content_iterator = await self._handle_streaming_response( | ||
| url, | ||
| payload, | ||
| headers, | ||
| domain_request.session_id or "", | ||
| "openai", | ||
| ) | ||
| except AuthenticationError as e: | ||
| raise HTTPException(status_code=401, detail=str(e)) | ||
| return StreamingResponseEnvelope( | ||
| content=content_iterator, | ||
| media_type="text/event-stream", | ||
| headers={}, | ||
| ) | ||
| else: | ||
| # Return a domain ResponseEnvelope for non-streaming | ||
| return await self._handle_non_streaming_response( | ||
| url, payload, headers, domain_request.session_id or "" | ||
| # Prepare the payload using a helper so subclasses and tests can | ||
| # override or patch payload construction logic easily. | ||
| payload = await self._prepare_payload( | ||
| domain_request, processed_messages, effective_model | ||
| ) | ||
| headers_override = kwargs.pop("headers_override", None) | ||
| headers: dict[str, str] | None = None | ||
|  | ||
| if headers_override is not None: | ||
| # Avoid mutating the caller-provided mapping while preserving any | ||
| # Authorization header we compute from the configured API key. | ||
| headers = dict(headers_override) | ||
|  | ||
| try: | ||
| base_headers = self.get_headers(effective_api_key) | ||
| except Exception: | ||
| base_headers = None | ||
| 
      Comment on lines
    
      +274
     to 
      +277
    
   There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
   Catch specific exceptions and log errors. These blocks catch a broad `Exception` and discard it without logging. Apply this diff to add logging:
             if headers_override is not None:
                 headers = dict(headers_override)
                 try:
                     base_headers = self.get_headers(effective_api_key)
-                except Exception:
+                except Exception as e:
+                    logger.warning("Failed to get base headers: %s", e, exc_info=True)
                     base_headers = None
                 if base_headers:
                     merged_headers = dict(base_headers)
                     merged_headers.update(headers)
                     headers = merged_headers
             else:
                 try:
-                    # Always update the cached identity so that per-request
-                    # identity headers do not leak between calls. Downstream
-                    # callers rely on identity-specific headers being scoped to
-                    # a single request.
-                    self.identity = identity
                     headers = self.get_headers(effective_api_key)
-                except Exception:
+                except Exception as e:
+                    logger.warning("Failed to get headers: %s", e, exc_info=True)
                     headers = None
Based on coding guidelines. Also applies to: 284-292 | ||
|  | ||
| if base_headers: | ||
| merged_headers = dict(base_headers) | ||
| merged_headers.update(headers) | ||
| headers = merged_headers | ||
| else: | ||
| try: | ||
| # Always update the cached identity so that per-request | ||
| # identity headers do not leak between calls. Downstream | ||
| # callers rely on identity-specific headers being scoped to | ||
| # a single request. | ||
| self.identity = identity | ||
| headers = self.get_headers(effective_api_key) | ||
| except Exception: | ||
| headers = None | ||
|  | ||
| api_base = kwargs.get("openai_url") or self.api_base_url | ||
| url = f"{api_base.rstrip('/')}/chat/completions" | ||
|  | ||
| if domain_request.stream: | ||
| # Return a domain-level streaming envelope (raw bytes iterator) | ||
| try: | ||
| content_iterator = await self._handle_streaming_response( | ||
| url, | ||
| payload, | ||
| headers, | ||
| domain_request.session_id or "", | ||
| "openai", | ||
| ) | ||
| except AuthenticationError as e: | ||
| raise HTTPException(status_code=401, detail=str(e)) | ||
| return StreamingResponseEnvelope( | ||
| content=content_iterator, | ||
| media_type="text/event-stream", | ||
| headers={}, | ||
| ) | ||
| else: | ||
| # Return a domain ResponseEnvelope for non-streaming | ||
| return await self._handle_non_streaming_response( | ||
| url, payload, headers, domain_request.session_id or "" | ||
| ) | ||
| finally: | ||
| self.identity = original_identity | ||
|  | ||
| async def _prepare_payload( | ||
| self, | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical: Identity mutation creates the same race condition that api_key had.
While the per-call `api_key` handling correctly avoids mutating instance state, the identity handling still uses the save/mutate/restore pattern that previous reviewers flagged as unsafe:
- `original_identity = self.identity`
- `self.identity = identity`
- `self.identity = identity` (redundant)
- `self.identity = original_identity` (restore in `finally`)
If this connector instance services concurrent requests with different identities, the same race condition occurs: one request can overwrite another's identity while both are in flight, and `get_headers()` (called at lines 275, 290) will read the wrong identity, producing incorrect headers for one or both requests.
Consider passing `identity` as a parameter to `get_headers()` and downstream methods rather than mutating `self.identity`, matching the pattern you correctly used for `api_key`.
Then update call sites to pass `identity` explicitly:
Also applies to: 259-259, 289-289, 320-320
🤖 Prompt for AI Agents