Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 85 additions & 73 deletions src/connectors/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,88 +225,100 @@ async def chat_completions(
processed_messages: list[Any],
effective_model: str,
identity: IAppIdentityConfig | None = None,
api_key: str | None = None,
**kwargs: Any,
) -> ResponseEnvelope | StreamingResponseEnvelope:
# Perform health check if enabled (for subclasses that support it)
await self._ensure_healthy()

# request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest)
# (the frontend controller converts from frontend-specific format to domain format)
# Backends should ONLY convert FROM domain TO backend-specific format
# Type assertion: we know from architectural design that request_data is ChatRequest-like
from typing import cast

from src.core.domain.chat import CanonicalChatRequest, ChatRequest

if not isinstance(request_data, ChatRequest):
raise TypeError(
f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. "
"Backend connectors should only receive domain-format requests."
)
# Cast to CanonicalChatRequest for mypy compatibility with _prepare_payload signature
domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data)
# Allow callers to supply a one-off API key (e.g., multi-tenant flows).
# Temporarily replace the connector-level key for the duration of this
# call so that header construction and health checks use it.
original_api_key = self.api_key
if api_key is not None:
self.api_key = api_key

# Ensure identity headers are scoped to the current request only.
self.identity = identity

# Prepare the payload using a helper so subclasses and tests can
# override or patch payload construction logic easily.
payload = await self._prepare_payload(
domain_request, processed_messages, effective_model
)
headers_override = kwargs.pop("headers_override", None)
headers: dict[str, str] | None = None
# Perform health check if enabled (for subclasses that support it)
try:
await self._ensure_healthy()

if headers_override is not None:
# Avoid mutating the caller-provided mapping while preserving any
# Authorization header we compute from the configured API key.
headers = dict(headers_override)
# request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest)
# (the frontend controller converts from frontend-specific format to domain format)
# Backends should ONLY convert FROM domain TO backend-specific format
# Type assertion: we know from architectural design that request_data is ChatRequest-like
from typing import cast

try:
base_headers = self.get_headers()
except Exception:
base_headers = None
from src.core.domain.chat import CanonicalChatRequest, ChatRequest

if base_headers:
merged_headers = dict(base_headers)
merged_headers.update(headers)
headers = merged_headers
else:
try:
# Always update the cached identity so that per-request
# identity headers do not leak between calls. Downstream
# callers rely on identity-specific headers being scoped to
# a single request.
self.identity = identity
headers = self.get_headers()
except Exception:
headers = None
if not isinstance(request_data, ChatRequest):
raise TypeError(
f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. "
"Backend connectors should only receive domain-format requests."
)
# Cast to CanonicalChatRequest for mypy compatibility with _prepare_payload signature
domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data)

api_base = kwargs.get("openai_url") or self.api_base_url
url = f"{api_base.rstrip('/')}/chat/completions"
# Ensure identity headers are scoped to the current request only.
self.identity = identity

if domain_request.stream:
# Return a domain-level streaming envelope (raw bytes iterator)
try:
content_iterator = await self._handle_streaming_response(
url,
payload,
headers,
domain_request.session_id or "",
"openai",
)
except AuthenticationError as e:
raise HTTPException(status_code=401, detail=str(e))
return StreamingResponseEnvelope(
content=content_iterator,
media_type="text/event-stream",
headers={},
)
else:
# Return a domain ResponseEnvelope for non-streaming
return await self._handle_non_streaming_response(
url, payload, headers, domain_request.session_id or ""
# Prepare the payload using a helper so subclasses and tests can
# override or patch payload construction logic easily.
payload = await self._prepare_payload(
domain_request, processed_messages, effective_model
)
headers_override = kwargs.pop("headers_override", None)
headers: dict[str, str] | None = None

if headers_override is not None:
# Avoid mutating the caller-provided mapping while preserving any
# Authorization header we compute from the configured API key.
headers = dict(headers_override)

try:
base_headers = self.get_headers()
except Exception:
base_headers = None

if base_headers:
merged_headers = dict(base_headers)
merged_headers.update(headers)
headers = merged_headers
else:
try:
# Always update the cached identity so that per-request
# identity headers do not leak between calls. Downstream
# callers rely on identity-specific headers being scoped to
# a single request.
self.identity = identity
headers = self.get_headers()
except Exception:
headers = None

api_base = kwargs.get("openai_url") or self.api_base_url
url = f"{api_base.rstrip('/')}/chat/completions"

if domain_request.stream:
# Return a domain-level streaming envelope (raw bytes iterator)
try:
content_iterator = await self._handle_streaming_response(
url,
payload,
headers,
domain_request.session_id or "",
"openai",
)
except AuthenticationError as e:
raise HTTPException(status_code=401, detail=str(e))
return StreamingResponseEnvelope(
content=content_iterator,
media_type="text/event-stream",
headers={},
)
else:
# Return a domain ResponseEnvelope for non-streaming
return await self._handle_non_streaming_response(
url, payload, headers, domain_request.session_id or ""
)
finally:
if api_key is not None:
self.api_key = original_api_key

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid mutating connector-level api_key for per-call override

The new api_key parameter works by temporarily assigning self.api_key and restoring it in a finally block (src/connectors/openai.py lines 222‑321). This introduces a race condition when the same OpenAIConnector instance services concurrent requests with different per-call keys: the attribute is global to the instance, so one request can overwrite another’s key while both are in flight, and whichever finishes first will reset the field to the original value. In a multi-tenant deployment this can produce Authorization headers that use the wrong tenant’s key or fail authentication sporadically. Consider passing the override down to header construction without mutating shared state (e.g., via local variables) or making the key thread-safe.

Useful? React with 👍 / 👎.


async def _prepare_payload(
self,
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/connectors/test_precision_payload_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from src.connectors.openrouter import OpenRouterBackend
from src.core.config.app_config import AppConfig
from src.core.domain.chat import ChatMessage, ChatRequest
from src.core.domain.responses import ResponseEnvelope
from src.core.services.translation_service import TranslationService


Expand Down Expand Up @@ -45,6 +46,53 @@ async def fake_post(url: str, json: dict, headers: dict) -> httpx.Response:
assert captured_payload.get("top_p") == 0.34


@pytest.mark.asyncio
async def test_openai_connector_uses_per_call_api_key(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The per-call ``api_key`` argument must drive the Authorization header.

    The connector-level key is cleared up front, so the only way the header
    can carry the token below is if the override passed to
    ``chat_completions`` was actually used when building request headers.
    """
    http_client = httpx.AsyncClient()
    connector = OpenAIConnector(
        http_client, AppConfig(), translation_service=TranslationService()
    )
    connector.disable_health_check()
    connector.api_key = None  # force reliance on the per-call override

    captured: list[dict[str, str] | None] = []

    async def record_headers(
        self: OpenAIConnector,
        url: str,
        payload: dict[str, Any],
        headers: dict[str, str] | None,
        session_id: str,
    ) -> ResponseEnvelope:
        # Capture the headers the connector would have sent upstream.
        captured.append(headers)
        return ResponseEnvelope(content={}, status_code=200, headers={})

    monkeypatch.setattr(
        OpenAIConnector, "_handle_non_streaming_response", record_headers
    )

    chat_request = ChatRequest(model="gpt-4o", messages=_messages(), stream=False)

    try:
        await connector.chat_completions(
            chat_request,
            chat_request.messages,
            chat_request.model,
            api_key="per-call-token",
        )
    finally:
        await http_client.aclose()

    assert captured and captured[0] is not None
    assert captured[0].get("Authorization") == "Bearer per-call-token"


@pytest.mark.asyncio
async def test_openai_payload_uses_processed_messages_with_list_content(
monkeypatch: pytest.MonkeyPatch,
Expand Down
Loading