Skip to content

Commit

Permalink
Merge branch 'development' into feat/supported-extensible-routes
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Jul 17, 2024
2 parents 89442eb + 8fbc61a commit 9585418
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 75 deletions.
5 changes: 2 additions & 3 deletions aidial_sdk/chat_completion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
Message,
Request,
Role,
Tool,
ToolCall,
ToolChoice,
)
from aidial_sdk.chat_completion.request import Stage as RequestStage
from aidial_sdk.chat_completion.request import Tool, ToolCall, ToolChoice
from aidial_sdk.chat_completion.response import Response
from aidial_sdk.chat_completion.stage import Stage
from aidial_sdk.deployment.tokenize import (
Expand Down
4 changes: 2 additions & 2 deletions aidial_sdk/chat_completion/enums.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from enum import Enum


class FinishReason(Enum):
class FinishReason(str, Enum):
STOP = "stop"
LENGTH = "length"
FUNCTION_CALL = "function_call"
TOOL_CALLS = "tool_calls"
CONTENT_FILTER = "content_filter"


class Status(Enum):
class Status(str, Enum):
COMPLETED = "completed"
FAILED = "failed"
11 changes: 10 additions & 1 deletion aidial_sdk/chat_completion/request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union

from aidial_sdk.chat_completion.enums import Status
from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
from aidial_sdk.pydantic_v1 import (
ConstrainedFloat,
Expand All @@ -21,7 +22,15 @@ class Attachment(ExtraForbidModel):
reference_url: Optional[StrictStr] = None


# Request-side model describing a single stage; instances are referenced by
# the `stages` field of `CustomContent`.
# NOTE(review): `ExtraForbidModel` presumably rejects unknown fields, as its
# name suggests — confirm in aidial_sdk.utils.pydantic.
class Stage(ExtraForbidModel):
# Stage name; required (no default).
name: StrictStr
# Stage status; required. `Status` comes from aidial_sdk.chat_completion.enums
# (members "completed" and "failed").
status: Status
# Optional stage content; defaults to None.
content: Optional[StrictStr] = None
# Optional list of attachments belonging to the stage; defaults to None.
attachments: Optional[List[Attachment]] = None


class CustomContent(ExtraForbidModel):
stages: Optional[List[Stage]] = None
attachments: Optional[List[Attachment]] = None
state: Optional[Any] = None

Expand All @@ -39,7 +48,7 @@ class ToolCall(ExtraForbidModel):
function: FunctionCall


class Role(Enum):
class Role(str, Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
Expand Down
8 changes: 8 additions & 0 deletions aidial_sdk/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from aidial_sdk.embeddings.base import Embeddings
from aidial_sdk.embeddings.request import (
Attachment,
EmbeddingsMultiModalInput,
EmbeddingsRequestCustomFields,
Request,
)
from aidial_sdk.embeddings.response import Embedding, Response, Usage
16 changes: 8 additions & 8 deletions aidial_sdk/embeddings/request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from typing import List, Literal, Optional, Union

from aidial_sdk.chat_completion.request import Attachment
from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
from aidial_sdk.pydantic_v1 import StrictFloat, StrictInt, StrictStr
from aidial_sdk.utils.pydantic import ExtraForbidModel
Expand All @@ -16,18 +16,18 @@ class AzureEmbeddingsRequest(ExtraForbidModel):
user: Optional[StrictStr] = None


class DialEmbeddingsType(str, Enum):
SYMMETRIC = "symmetric"
DOCUMENT = "document"
QUERY = "query"


class EmbeddingsRequestCustomFields(ExtraForbidModel):
type: Optional[DialEmbeddingsType] = None
type: Optional[StrictStr] = None
instruction: Optional[StrictStr] = None


EmbeddingsMultiModalInput = Union[
StrictStr, Attachment, List[Union[StrictStr, Attachment]]
]


class EmbeddingsRequest(AzureEmbeddingsRequest):
custom_input: Optional[List[EmbeddingsMultiModalInput]] = None
custom_fields: Optional[EmbeddingsRequestCustomFields] = None


Expand Down
38 changes: 22 additions & 16 deletions aidial_sdk/header_propagator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import types
from contextvars import ContextVar
from typing import Optional
from typing import MutableMapping, Optional

import aiohttp
import httpx
Expand Down Expand Up @@ -79,21 +79,12 @@ async def _on_aiohttp_request_start(
trace_config_ctx: types.SimpleNamespace,
params: aiohttp.TraceRequestStartParams,
):
if not str(params.url).startswith(self._dial_url):
return

api_key_val = self._api_key.get()
if api_key_val:
params.headers["api-key"] = api_key_val
self._modify_headers(str(params.url), params.headers)

def _instrument_requests(self):
def instrumented_send(wrapped, instance, args, kwargs):
request: requests.PreparedRequest = args[0]
if request.url and request.url.startswith(self._dial_url):
api_key_val = self._api_key.get()
if api_key_val:
request.headers["api-key"] = api_key_val

self._modify_headers(request.url or "", request.headers)
return wrapped(*args, **kwargs)

wrapt.wrap_function_wrapper(requests.Session, "send", instrumented_send)
Expand All @@ -102,10 +93,7 @@ def _instrument_httpx(self):

def instrumented_build_request(wrapped, instance, args, kwargs):
request: httpx.Request = wrapped(*args, **kwargs)
if request.url and str(request.url).startswith(self._dial_url):
api_key_val = self._api_key.get()
if api_key_val:
request.headers["api-key"] = api_key_val
self._modify_headers(str(request.url), request.headers)
return request

wrapt.wrap_function_wrapper(
Expand All @@ -115,3 +103,21 @@ def instrumented_build_request(wrapped, instance, args, kwargs):
wrapt.wrap_function_wrapper(
httpx.AsyncClient, "build_request", instrumented_build_request
)

def _modify_headers(
    self, url: str, headers: MutableMapping[str, str]
) -> None:
    """Put the propagated API key into the headers of a DIAL-bound request.

    Nothing is changed unless ``url`` starts with ``self._dial_url`` and
    ``self._api_key.get()`` yields a truthy value.  Otherwise ``headers`` is
    mutated in place:

    * ``Authorization`` is rewritten to ``Bearer <new key>`` only when the
      incoming headers already hold an ``api-key`` and an ``Authorization``
      equal to ``Bearer <that api-key>``; any other ``Authorization`` value
      is left as it is.
    * ``api-key`` is always overwritten with the new key.

    Args:
        url: Target URL of the outgoing request, as a string.
        headers: Mutable header mapping of the outgoing request.
    """
    # Requests that do not target the DIAL URL are never modified.
    if not url.startswith(self._dial_url):
        return

    new_key = self._api_key.get()
    if not new_key:
        return

    previous_key = headers.get("api-key")
    previous_auth = headers.get("Authorization")

    # Keep the bearer token in sync only when it mirrored the old api-key.
    if previous_key and previous_auth:
        if previous_auth == f"Bearer {previous_key}":
            headers["Authorization"] = f"Bearer {new_key}"

    headers["api-key"] = new_key
13 changes: 6 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions tests/header_propagation/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,35 @@ class Library(str, Enum):
class Request(BaseModel):
url: str
lib: Library
headers: dict


@app.post("/")
async def handle(request: Request):
url = request.url
lib = request.lib
headers = request.headers

if lib == Library.requests:
response = requests.get(url)
response = requests.get(url, headers=headers)
status_code = response.status_code
content = response.json()

elif lib == Library.httpx_async:
async with httpx.AsyncClient() as client:
response = await client.get(url)
response = await client.get(url, headers=headers)
status_code = response.status_code
content = response.json()

elif lib == Library.httpx_sync:
with httpx.Client() as client:
response = client.get(url)
response = client.get(url, headers=headers)
status_code = response.status_code
content = response.json()

elif lib == Library.aiohttp:
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
async with session.get(url, headers=headers) as response:
status_code = response.status
content = await response.json()

Expand Down
11 changes: 0 additions & 11 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from aidial_sdk import DIALApp
from tests.applications.simple_embeddings import SimpleEmbeddings
from tests.utils.endpoint_test import TestCase, run_endpoint_test
from tests.utils.errors import invalid_request_error

deployment = "test-app"
app = DIALApp().add_embeddings(deployment, SimpleEmbeddings())
Expand Down Expand Up @@ -37,16 +36,6 @@
{"input": "a", "custom_fields": {"type": "query"}},
expected_response_1,
),
TestCase(
app,
deployment,
"embeddings",
{"input": "a", "custom_fields": {"type": "hello"}},
invalid_request_error(
"custom_fields.type",
"value is not a valid enumeration member; permitted: 'symmetric', 'document', 'query'",
),
),
TestCase(
app,
deployment,
Expand Down
Loading

0 comments on commit 9585418

Please sign in to comment.