Skip to content

Commit

Permalink
Fix #5976 PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
juliuslipp committed Jun 11, 2023
1 parent 06a25a4 commit 9a43b68
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 16 deletions.
24 changes: 16 additions & 8 deletions langchain/embeddings/embaas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Wrapper around Embaas embeddings API."""
"""Wrapper around embaas embeddings API."""
from typing import Any, Dict, List, Mapping, Optional
from typing_extensions import TypedDict, NotRequired

import requests
from pydantic import BaseModel, Extra, root_validator
Expand All @@ -12,15 +13,24 @@
EMBAAS_API_URL = "https://api.embaas.io/v1/embeddings/"


class EmbaasEmbeddingsPayload(TypedDict):
"""Payload for the embaas embeddings API."""

model: str
texts: List[str]
instruction: NotRequired[str]


class EmbaasEmbeddings(BaseModel, Embeddings):
"""Wrapper around Embaas's embedding service.
"""Wrapper around embaas's embedding service.
To use, you should have the
environment variable ``EMBAAS_API_KEY`` set with your API key, or pass
it as a named parameter to the constructor.
Example:
.. code-block:: python
# Initialise with default model and instruction
from langchain.embeddings.embaas import EmbaasEmbeddings
emb = EmbaasEmbeddings()
Expand All @@ -45,7 +55,6 @@ class EmbaasEmbeddings(BaseModel, Embeddings):

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid

@root_validator()
Expand All @@ -62,14 +71,14 @@ def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying params."""
return {"model": self.model, "instruction": self.instruction}

def _generate_payload(self, texts: List[str]) -> Dict[str, Any]:
def _generate_payload(self, texts: List[str]) -> EmbaasEmbeddingsPayload:
"""Generates payload for the API request."""
payload = {"texts": texts, "model": self.model}
payload = EmbaasEmbeddingsPayload(texts=texts, model=self.model)
if self.instruction:
payload["instruction"] = self.instruction
return payload

def _handle_request(self, payload: Dict[str, Any]) -> List[List[float]]:
def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]:
"""Sends a request to the Embaas API and handles the response."""
headers = {
"Authorization": f"Bearer {self.embaas_api_key}",
Expand Down Expand Up @@ -114,8 +123,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
for i in range(0, len(texts), MAX_BATCH_SIZE)]
embeddings = [self._generate_embeddings(batch) for batch in batches]
# flatten the list of lists into a single list
embeddings = [embedding for batch in embeddings for embedding in batch]
return embeddings
return [embedding for batch in embeddings for embedding in batch]

def embed_query(self, text: str) -> List[float]:
"""Get embeddings for a single text.
Expand Down
38 changes: 30 additions & 8 deletions tests/integration_tests/embeddings/test_embaas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Test Embaas embeddings."""
from langchain.embeddings.embaas import EmbaasEmbeddings
"""Test embaas embeddings."""
import responses

from langchain.embeddings.embaas import EmbaasEmbeddings, EMBAAS_API_URL


def test_embaas_embed_documents() -> None:
"""Test Embaas embeddings with multiple texts."""
"""Test embaas embeddings with multiple texts."""
texts = ["foo bar", "bar foo", "foo"]
embedding = EmbaasEmbeddings()
output = embedding.embed_documents(texts)
Expand All @@ -14,25 +16,45 @@ def test_embaas_embed_documents() -> None:


def test_embaas_embed_query() -> None:
"""Test Embaas embeddings with multiple texts."""
texts = "foo"
"""Test embaas embeddings with a single query text."""
text = "foo"
embeddings = EmbaasEmbeddings()
output = embeddings.embed_query("foo")
output = embeddings.embed_query(text)
assert len(output) == 1024


def test_embaas_embed_query_instruction() -> None:
"""Test Embaas embeddings with a different instruction."""
"""Test embaas embeddings with a different instruction."""
text = "Test"
embeddings = EmbaasEmbeddings(instruction="Query")
instruction = "query"
embeddings = EmbaasEmbeddings(instruction=instruction)
output = embeddings.embed_query(text)
assert len(output) == 1024


def test_embaas_embed_query_model() -> None:
"""Test embaas embeddings with a different model."""
text = "Test"
model = "instructor-large"
instruction = "Represent the query for retrieval"
embeddings = EmbaasEmbeddings(model=model, instruction=instruction)
output = embeddings.embed_query(text)
assert len(output) == 768


@responses.activate
def test_embaas_embed_documents_response() -> None:
"""Test embaas embed_query for a single text against a mocked API response."""
responses.add(responses.POST, EMBAAS_API_URL,
json={
"data": [
{
'embedding': [0.0] * 1024
}
]
}, status=200)

text = "asd"
embeddings = EmbaasEmbeddings()
output = embeddings.embed_query(text)
assert len(output) == 1024

0 comments on commit 9a43b68

Please sign in to comment.