Skip to content

Commit

Permalink
feat: added custom_input for multi-modal embeddings (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Jul 17, 2024
1 parent 2e0e4dd commit 8fbc61a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 24 deletions.
8 changes: 8 additions & 0 deletions aidial_sdk/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from aidial_sdk.embeddings.base import Embeddings
from aidial_sdk.embeddings.request import (
Attachment,
EmbeddingsMultiModalInput,
EmbeddingsRequestCustomFields,
Request,
)
from aidial_sdk.embeddings.response import Embedding, Response, Usage
16 changes: 8 additions & 8 deletions aidial_sdk/embeddings/request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from typing import List, Literal, Optional, Union

from aidial_sdk.chat_completion.request import Attachment
from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
from aidial_sdk.pydantic_v1 import StrictFloat, StrictInt, StrictStr
from aidial_sdk.utils.pydantic import ExtraForbidModel
Expand All @@ -16,18 +16,18 @@ class AzureEmbeddingsRequest(ExtraForbidModel):
user: Optional[StrictStr] = None


class DialEmbeddingsType(str, Enum):
SYMMETRIC = "symmetric"
DOCUMENT = "document"
QUERY = "query"


class EmbeddingsRequestCustomFields(ExtraForbidModel):
type: Optional[DialEmbeddingsType] = None
type: Optional[StrictStr] = None
instruction: Optional[StrictStr] = None


EmbeddingsMultiModalInput = Union[
StrictStr, Attachment, List[Union[StrictStr, Attachment]]
]


class EmbeddingsRequest(AzureEmbeddingsRequest):
custom_input: Optional[List[EmbeddingsMultiModalInput]] = None
custom_fields: Optional[EmbeddingsRequestCustomFields] = None


Expand Down
16 changes: 0 additions & 16 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from tests.applications.simple_embeddings import SimpleEmbeddings
from tests.utils.endpoint_test import TestCase, run_endpoint_test
from tests.utils.errors import Error

simple = SimpleEmbeddings()

Expand Down Expand Up @@ -34,21 +33,6 @@
{"input": "a", "custom_fields": {"type": "query"}},
expected_response_1,
),
TestCase(
simple,
"embeddings",
{"input": "a", "custom_fields": {"type": "hello"}},
Error(
code=400,
error={
"error": {
"message": "Your request contained invalid structure on path custom_fields.type. "
"value is not a valid enumeration member; permitted: 'symmetric', 'document', 'query'",
"type": "invalid_request_error",
}
},
),
),
TestCase(
simple,
"embeddings",
Expand Down

0 comments on commit 8fbc61a

Please sign in to comment.