
Allow extra query params to be sent to the OpenAI server #146

Merged · 6 commits · May 7, 2025
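
This change adds an extra_query option to OpenAIHTTPBackend so arbitrary query parameters can be attached to outgoing requests, either globally or scoped per endpoint, and records those parameters on RequestArgs. A minimal sketch of how the option can be used; the target URL and parameter names/values below are placeholders, not taken from this PR:

```python
from guidellm.backend.openai import OpenAIHTTPBackend

# Applied to every request: a flat dict of query parameters.
backend = OpenAIHTTPBackend(
    target="http://localhost:8000",              # placeholder server URL
    extra_query={"api-version": "2024-05-01"},   # placeholder parameter
)

# Scoped per endpoint: keys match the new EndpointType values, so each
# endpoint only receives its own parameters.
backend = OpenAIHTTPBackend(
    target="http://localhost:8000",
    extra_query={
        "chat_completions": {"api-version": "2024-05-01"},
        "models": {"api-version": "2024-02-01"},
    },
)
```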
65 changes: 56 additions & 9 deletions src/guidellm/backend/openai.py
@@ -17,12 +17,24 @@
)
from guidellm.config import settings

__all__ = ["CHAT_COMPLETIONS_PATH", "TEXT_COMPLETIONS_PATH", "OpenAIHTTPBackend"]
__all__ = [
"CHAT_COMPLETIONS",
"CHAT_COMPLETIONS_PATH",
"MODELS",
"TEXT_COMPLETIONS",
"TEXT_COMPLETIONS_PATH",
"OpenAIHTTPBackend",
]


TEXT_COMPLETIONS_PATH = "/v1/completions"
CHAT_COMPLETIONS_PATH = "/v1/chat/completions"

EndpointType = Literal["chat_completions", "models", "text_completions"]
CHAT_COMPLETIONS: EndpointType = "chat_completions"
MODELS: EndpointType = "models"
TEXT_COMPLETIONS: EndpointType = "text_completions"


@Backend.register("openai_http")
class OpenAIHTTPBackend(Backend):
@@ -53,6 +65,11 @@ class OpenAIHTTPBackend(Backend):
If not provided, the default value from settings is used.
:param max_output_tokens: The maximum number of tokens to request for completions.
If not provided, the default maximum tokens provided from settings is used.
:param extra_query: Query parameters to include in requests to the OpenAI server.
If "chat_completions", "models", or "text_completions" are included as keys,
the values of these keys will be used as the parameters for the respective
endpoint.
If not provided, no extra query parameters are added.
"""

def __init__(
@@ -66,6 +83,7 @@
http2: Optional[bool] = True,
follow_redirects: Optional[bool] = None,
max_output_tokens: Optional[int] = None,
extra_query: Optional[dict] = None,
):
super().__init__(type_="openai_http")
self._target = target or settings.openai.base_url
@@ -101,6 +119,7 @@ def __init__(
if max_output_tokens is not None
else settings.openai.max_output_tokens
)
self.extra_query = extra_query
self._async_client: Optional[httpx.AsyncClient] = None

@property
@@ -174,7 +193,10 @@ async def available_models(self) -> list[str]:
"""
target = f"{self.target}/v1/models"
headers = self._headers()
response = await self._get_async_client().get(target, headers=headers)
params = self._params(MODELS)
response = await self._get_async_client().get(
target, headers=headers, params=params
)
response.raise_for_status()

models = []
@@ -219,6 +241,7 @@ async def text_completions( # type: ignore[override]
)

headers = self._headers()
params = self._params(TEXT_COMPLETIONS)
payload = self._completions_payload(
orig_kwargs=kwargs,
max_output_tokens=output_token_count,
@@ -232,14 +255,16 @@
request_prompt_tokens=prompt_token_count,
request_output_tokens=output_token_count,
headers=headers,
params=params,
payload=payload,
):
yield resp
except Exception as ex:
logger.error(
"{} request with headers: {} and payload: {} failed: {}",
"{} request with headers: {} and params: {} and payload: {} failed: {}",
self.__class__.__name__,
headers,
params,
payload,
ex,
)
@@ -291,6 +316,7 @@ async def chat_completions( # type: ignore[override]
"""
logger.debug("{} invocation with args: {}", self.__class__.__name__, locals())
headers = self._headers()
params = self._params(CHAT_COMPLETIONS)
messages = (
content if raw_content else self._create_chat_messages(content=content)
)
@@ -307,14 +333,16 @@
request_prompt_tokens=prompt_token_count,
request_output_tokens=output_token_count,
headers=headers,
params=params,
payload=payload,
):
yield resp
except Exception as ex:
logger.error(
"{} request with headers: {} and payload: {} failed: {}",
"{} request with headers: {} and params: {} and payload: {} failed: {}",
self.__class__.__name__,
headers,
params,
payload,
ex,
)
@@ -355,6 +383,19 @@ def _headers(self) -> dict[str, str]:

return headers

def _params(self, endpoint_type: EndpointType) -> dict[str, str]:
if self.extra_query is None:
return {}

if (
CHAT_COMPLETIONS in self.extra_query
or MODELS in self.extra_query
or TEXT_COMPLETIONS in self.extra_query
):
return self.extra_query.get(endpoint_type, {})

return self.extra_query

def _completions_payload(
self, orig_kwargs: Optional[dict], max_output_tokens: Optional[int], **kwargs
) -> dict:
@@ -451,8 +492,9 @@ async def _iterative_completions_request(
request_id: Optional[str],
request_prompt_tokens: Optional[int],
request_output_tokens: Optional[int],
headers: dict,
payload: dict,
headers: dict[str, str],
params: dict[str, str],
payload: dict[str, Any],
) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]:
if type_ == "text_completions":
target = f"{self.target}{TEXT_COMPLETIONS_PATH}"
@@ -463,14 +505,16 @@

logger.info(
"{} making request: {} to target: {} using http2: {} following "
"redirects: {} for timeout: {} with headers: {} and payload: {}",
"redirects: {} for timeout: {} with headers: {} and params: {} and ",
"payload: {}",
self.__class__.__name__,
request_id,
target,
self.http2,
self.follow_redirects,
self.timeout,
headers,
params,
payload,
)

@@ -498,7 +542,7 @@ async def _iterative_completions_request(
start_time = time.time()

async with self._get_async_client().stream(
"POST", target, headers=headers, json=payload
"POST", target, headers=headers, params=params, json=payload
) as stream:
stream.raise_for_status()

@@ -542,10 +586,12 @@ async def _iterative_completions_request(
response_output_count = usage["output"]

logger.info(
"{} request: {} with headers: {} and payload: {} completed with: {}",
"{} request: {} with headers: {} and params: {} and payload: {} completed"
"with: {}",
self.__class__.__name__,
request_id,
headers,
params,
payload,
response_value,
)
@@ -555,6 +601,7 @@
request_args=RequestArgs(
target=target,
headers=headers,
params=params,
payload=payload,
timeout=self.timeout,
http2=self.http2,
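
For reference, a rough sketch of how the _params helper added above resolves query parameters for a given endpoint. Construction with only a target follows the pattern used in the unit tests; the parameter names and values are illustrative only:

```python
from guidellm.backend.openai import OpenAIHTTPBackend

backend = OpenAIHTTPBackend(
    target="http://localhost:8000",                       # placeholder server URL
    extra_query={"text_completions": {"echo": "true"}},   # illustrative value
)

# Endpoint-keyed dicts are looked up per endpoint; endpoints without an
# entry get no extra parameters.
assert backend._params("text_completions") == {"echo": "true"}
assert backend._params("chat_completions") == {}

# A dict without any endpoint keys is returned unchanged for every endpoint.
backend.extra_query = {"trace": "1"}
assert backend._params("models") == {"trace": "1"}
```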
2 changes: 2 additions & 0 deletions src/guidellm/backend/response.py
@@ -48,6 +48,7 @@ class RequestArgs(StandardBaseModel):

:param target: The target URL or function for the request.
:param headers: The headers, if any, included in the request such as authorization.
:param params: The query parameters, if any, included in the request.
:param payload: The payload / arguments for the request including the prompt /
content and other configurations.
:param timeout: The timeout for the request in seconds, if any.
@@ -57,6 +58,7 @@

target: str
headers: dict[str, str]
params: dict[str, str]
payload: dict[str, Any]
timeout: Optional[float] = None
http2: Optional[bool] = None
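
With the new field, recorded request metadata now includes the query parameters that were sent. A small sketch, assuming RequestArgs is importable from guidellm.backend.response (the module changed here); the header and parameter values are illustrative:

```python
from guidellm.backend.response import RequestArgs

# Query parameters sent with the request are captured alongside the
# headers and payload when building result summaries.
args = RequestArgs(
    target="http://example.com/v1/chat/completions",
    headers={"Authorization": "Bearer token"},   # illustrative value
    params={"api-version": "2024-05-01"},        # illustrative value
    payload={"messages": []},
)
```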
4 changes: 3 additions & 1 deletion src/guidellm/dataset/synthetic.py
@@ -200,7 +200,9 @@ def _create_prompt(self, prompt_tokens: int, start_index: int) -> str:

class SyntheticDatasetCreator(DatasetCreator):
@classmethod
def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool: # noqa: ARG003
def is_supported(
cls, data: Any, data_args: Optional[dict[str, Any]] # noqa: ARG003
) -> bool:
if (
isinstance(data, Path)
and data.exists()
2 changes: 2 additions & 0 deletions src/guidellm/scheduler/worker.py
@@ -475,6 +475,7 @@ def _handle_response(
request_args=RequestArgs(
target=self.backend.target,
headers={},
params={},
payload={},
),
start_time=resolve_start_time,
@@ -490,6 +491,7 @@
request_args=RequestArgs(
target=self.backend.target,
headers={},
params={},
payload={},
),
start_time=response.start_time,
3 changes: 3 additions & 0 deletions tests/unit/backend/test_openai_backend.py
@@ -18,6 +18,7 @@ def test_openai_http_backend_default_initialization():
assert backend.http2 is True
assert backend.follow_redirects is True
assert backend.max_output_tokens == settings.openai.max_output_tokens
assert backend.extra_query is None


@pytest.mark.smoke
@@ -32,6 +33,7 @@ def test_openai_http_backend_intialization():
http2=False,
follow_redirects=False,
max_output_tokens=100,
extra_query={"foo": "bar"},
)
assert backend.target == "http://test-target"
assert backend.model == "test-model"
@@ -42,6 +44,7 @@ def test_openai_http_backend_intialization():
assert backend.http2 is False
assert backend.follow_redirects is False
assert backend.max_output_tokens == 100
assert backend.extra_query == {"foo": "bar"}


@pytest.mark.smoke
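
A possible follow-up unit test for the endpoint-scoped behavior, sketched here rather than included in the PR. It assumes the existing imports in tests/unit/backend/test_openai_backend.py, and the parameter name and value are hypothetical:

```python
@pytest.mark.smoke
def test_openai_http_backend_extra_query_scoped_by_endpoint():
    backend = OpenAIHTTPBackend(
        extra_query={"models": {"api-version": "2024-05-01"}}  # hypothetical value
    )
    # Only the models endpoint should receive the extra parameter.
    assert backend._params("models") == {"api-version": "2024-05-01"}
    assert backend._params("chat_completions") == {}
    assert backend._params("text_completions") == {}
```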
5 changes: 5 additions & 0 deletions tests/unit/backend/test_response.py
@@ -76,6 +76,7 @@ def test_request_args_default_initialization():
args = RequestArgs(
target="http://example.com",
headers={},
params={},
payload={},
)
assert args.timeout is None
@@ -90,6 +91,7 @@ def test_request_args_initialization():
headers={
"Authorization": "Bearer token",
},
params={},
payload={
"query": "Hello, world!",
},
@@ -110,6 +112,7 @@ def test_response_args_marshalling():
args = RequestArgs(
target="http://example.com",
headers={"Authorization": "Bearer token"},
params={},
payload={"query": "Hello, world!"},
timeout=10.0,
http2=True,
@@ -128,6 +131,7 @@ def test_response_summary_default_initialization():
request_args=RequestArgs(
target="http://example.com",
headers={},
params={},
payload={},
),
start_time=0.0,
@@ -158,6 +162,7 @@ def test_response_summary_initialization():
request_args=RequestArgs(
target="http://example.com",
headers={},
params={},
payload={},
),
start_time=1.0,
1 change: 1 addition & 0 deletions tests/unit/mock_backend.py
@@ -142,6 +142,7 @@ async def _text_prompt_response_generator(
request_args=RequestArgs(
target=self.target,
headers={},
params={},
payload={"prompt": prompt, "output_token_count": output_token_count},
),
iterations=len(tokens),
Expand Down
Loading