Skip to content

Commit 4f7d243

Browse files
authored
Merge pull request #31 from JigsawStack/feat/embedding-api
Add support for embedding generation 🚀
2 parents f794292 + 5d58d0e commit 4f7d243

File tree

4 files changed

+112
-1
lines changed

4 files changed

+112
-1
lines changed

jigsawstack/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .summary import Summary, AsyncSummary
1414
from .geo import Geo, AsyncGeo
1515
from .prompt_engine import PromptEngine, AsyncPromptEngine
16+
from .embedding import Embedding, AsyncEmbedding
1617
from .exceptions import JigsawStackError
1718

1819

@@ -110,6 +111,11 @@ def __init__(
110111
api_url=api_url,
111112
disable_request_logging=disable_request_logging,
112113
)
114+
self.embedding = Embedding(
115+
api_key=api_key,
116+
api_url=api_url,
117+
disable_request_logging=disable_request_logging,
118+
).execute
113119

114120

115121
class AsyncJigsawStack:
@@ -215,6 +221,11 @@ def __init__(
215221
api_url=api_url,
216222
disable_request_logging=disable_request_logging,
217223
)
224+
self.embedding = AsyncEmbedding(
225+
api_key=api_key,
226+
api_url=api_url,
227+
disable_request_logging=disable_request_logging,
228+
).execute
218229

219230

220231
# Create a global instance of the Web class

jigsawstack/embedding.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from typing import Any, Dict, List, Union, cast, Literal
2+
from typing_extensions import NotRequired, TypedDict
3+
from .request import Request, RequestConfig
4+
from .async_request import AsyncRequest
5+
from typing import List, Union
6+
from ._config import ClientConfig
7+
8+
9+
class EmbeddingParams(TypedDict):
10+
text: NotRequired[str]
11+
file_content: NotRequired[Any]
12+
type: Literal["text", "text-other", "image", "audio", "pdf"]
13+
url: NotRequired[str]
14+
file_store_key: NotRequired[str]
15+
token_overflow_mode: NotRequired[Literal["truncate", "chunk", "error"]] = "chunk"
16+
17+
18+
class EmbeddingResponse(TypedDict):
19+
success: bool
20+
embeddings: List[List[float]]
21+
chunks: List[str]
22+
23+
24+
class Embedding(ClientConfig):
25+
26+
config: RequestConfig
27+
28+
def __init__(
29+
self,
30+
api_key: str,
31+
api_url: str,
32+
disable_request_logging: Union[bool, None] = False,
33+
):
34+
super().__init__(api_key, api_url, disable_request_logging)
35+
self.config = RequestConfig(
36+
api_url=api_url,
37+
api_key=api_key,
38+
disable_request_logging=disable_request_logging,
39+
)
40+
41+
def execute(self, params: EmbeddingParams) -> EmbeddingResponse:
42+
path = "/embedding"
43+
resp = Request(
44+
config=self.config,
45+
path=path,
46+
params=cast(Dict[Any, Any], params),
47+
verb="post",
48+
).perform_with_content()
49+
return resp
50+
51+
52+
class AsyncEmbedding(ClientConfig):
53+
54+
config: RequestConfig
55+
56+
def __init__(
57+
self,
58+
api_key: str,
59+
api_url: str,
60+
disable_request_logging: Union[bool, None] = False,
61+
):
62+
super().__init__(api_key, api_url, disable_request_logging)
63+
self.config = RequestConfig(
64+
api_url=api_url,
65+
api_key=api_key,
66+
disable_request_logging=disable_request_logging,
67+
)
68+
69+
async def execute(self, params: EmbeddingParams) -> EmbeddingResponse:
70+
path = "/embedding"
71+
resp = await AsyncRequest(
72+
config=self.config,
73+
path=path,
74+
params=cast(Dict[Any, Any], params),
75+
verb="post",
76+
).perform_with_content()
77+
return resp

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
setup(
88
name="jigsawstack",
9-
version="0.1.24",
9+
version="0.1.25",
1010
description="JigsawStack Python SDK",
1111
long_description=open("README.md", encoding="utf8").read(),
1212
long_description_content_type="text/markdown",

tests/test_embedding_async.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from unittest.mock import MagicMock
2+
import unittest
3+
from jigsawstack.exceptions import JigsawStackError
4+
from jigsawstack import AsyncJigsawStack
5+
import pytest
6+
import asyncio
7+
import logging
8+
9+
logging.basicConfig(level=logging.INFO)
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def test_async_embedding_generation_response():
14+
async def _test():
15+
client = AsyncJigsawStack()
16+
try:
17+
result = await client.embedding({"text": "Hello, World!", "type": "text"})
18+
logger.info(result)
19+
assert result["success"] == True
20+
except JigsawStackError as e:
21+
pytest.fail(f"Unexpected JigsawStackError: {e}")
22+
23+
asyncio.run(_test())

0 commit comments

Comments
 (0)