Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add embaas embeddings api endpoints #5976

Merged
merged 8 commits into from
Jun 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix #5976 PR comments
  • Loading branch information
juliuslipp committed Jun 11, 2023
commit 9a43b68f4a19e6adce6064ab7c0c5dc6efe1909d
24 changes: 16 additions & 8 deletions langchain/embeddings/embaas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Wrapper around Embaas embeddings API."""
"""Wrapper around embaas embeddings API."""
from typing import Any, Dict, List, Mapping, Optional
from typing_extensions import TypedDict, NotRequired

import requests
from pydantic import BaseModel, Extra, root_validator
Expand All @@ -12,15 +13,24 @@
EMBAAS_API_URL = "https://api.embaas.io/v1/embeddings/"


class EmbaasEmbeddingsPayload(TypedDict):
"""Payload for the embaas embeddings API."""

model: str
texts: List[str]
instruction: NotRequired[str]


class EmbaasEmbeddings(BaseModel, Embeddings):
"""Wrapper around Embaas's embedding service.
"""Wrapper around embaas's embedding service.

To use, you should have the
environment variable ``EMBAAS_API_KEY`` set with your API key, or pass
it as a named parameter to the constructor.

Example:
.. code-block:: python
juliuslipp marked this conversation as resolved.
Show resolved Hide resolved

# Initialise with default model and instruction
from langchain.llms import EmbaasEmbeddings
emb = EmbaasEmbeddings()
Expand All @@ -45,7 +55,6 @@ class EmbaasEmbeddings(BaseModel, Embeddings):

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid

@root_validator()
Expand All @@ -62,14 +71,14 @@ def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying params."""
return {"model": self.model, "instruction": self.instruction}

def _generate_payload(self, texts: List[str]) -> Dict[str, Any]:
def _generate_payload(self, texts: List[str]) -> EmbaasEmbeddingsPayload:
"""Generates payload for the API request."""
payload = {"texts": texts, "model": self.model}
payload = EmbaasEmbeddingsPayload(texts=texts, model=self.model)
if self.instruction:
payload["instruction"] = self.instruction
return payload

def _handle_request(self, payload: Dict[str, Any]) -> List[List[float]]:
def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]:
"""Sends a request to the Embaas API and handles the response."""
headers = {
"Authorization": f"Bearer {self.embaas_api_key}",
Expand Down Expand Up @@ -114,8 +123,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
for i in range(0, len(texts), MAX_BATCH_SIZE)]
embeddings = [self._generate_embeddings(batch) for batch in batches]
# flatten the list of lists into a single list
embeddings = [embedding for batch in embeddings for embedding in batch]
return embeddings
return [embedding for batch in embeddings for embedding in batch]

def embed_query(self, text: str) -> List[float]:
"""Get embeddings for a single text.
Expand Down
38 changes: 30 additions & 8 deletions tests/integration_tests/embeddings/test_embaas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Test Embaas embeddings."""
from langchain.embeddings.embaas import EmbaasEmbeddings
"""Test embaas embeddings."""
import responses

from langchain.embeddings.embaas import EmbaasEmbeddings, EMBAAS_API_URL


def test_embaas_embed_documents() -> None:
"""Test Embaas embeddings with multiple texts."""
"""Test embaas embeddings with multiple texts."""
texts = ["foo bar", "bar foo", "foo"]
embedding = EmbaasEmbeddings()
output = embedding.embed_documents(texts)
Expand All @@ -14,25 +16,45 @@ def test_embaas_embed_documents() -> None:


def test_embaas_embed_query() -> None:
"""Test Embaas embeddings with multiple texts."""
texts = "foo"
"""Test embaas embeddings with multiple texts."""
text = "foo"
embeddings = EmbaasEmbeddings()
output = embeddings.embed_query("foo")
output = embeddings.embed_query(text)
assert len(output) == 1024


def test_embaas_embed_query_instruction() -> None:
"""Test Embaas embeddings with a different instruction."""
"""Test embaas embeddings with a different instruction."""
text = "Test"
embeddings = EmbaasEmbeddings(instruction="Query")
instruction = "query"
embeddings = EmbaasEmbeddings(instruction=instruction)
output = embeddings.embed_query(text)
assert len(output) == 1024


def test_embaas_embed_query_model() -> None:
"""Test embaas embeddings with a different model."""
text = "Test"
model = "instructor-large"
instruction = "Represent the query for retrieval"
embeddings = EmbaasEmbeddings(model=model, instruction=instruction)
output = embeddings.embed_query(text)
assert len(output) == 768


@responses.activate
def test_embaas_embed_documents_response() -> None:
"""Test embaas embeddings with multiple texts."""
responses.add(responses.POST, EMBAAS_API_URL,
json={
"data": [
{
'embedding': [0.0] * 1024
}
]
}, status=200)

text = "asd"
embeddings = EmbaasEmbeddings()
output = embeddings.embed_query(text)
assert len(output) == 1024