"""Module providing a wrapper around OctoAI Compute Service embedding models."""

from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from octoai import client

DEFAULT_EMBED_INSTRUCTION = "Represent this input: "
DEFAULT_QUERY_INSTRUCTION = "Represent the question for retrieving similar documents: "


class OctoAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around OctoAI Compute Service embedding models.

    The environment variable ``OCTOAI_API_TOKEN`` should be set with your API
    token, or the token can be passed as a named parameter to the constructor.
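
    Example (a minimal sketch; the import path, endpoint URL, and token below
    are illustrative placeholders, not values provided by this module):
        .. code-block:: python

            from langchain.embeddings.octoai_embeddings import OctoAIEmbeddings

            embeddings = OctoAIEmbeddings(
                endpoint_url="https://<your-endpoint>.octoai.cloud/predict",
                octoai_api_token="my-octoai-api-token",
            )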
| 19 | + """ |
| 20 | + endpoint_url: Optional[str] = Field( |
| 21 | + None, description="Endpoint URL to use.") |
| 22 | + model_kwargs: Optional[dict] = Field( |
| 23 | + None, description="Keyword arguments to pass to the model.") |
| 24 | + octoai_api_token: Optional[str] = Field( |
| 25 | + None, description="OCTOAI API Token") |
| 26 | + embed_instruction: str = Field( |
| 27 | + DEFAULT_EMBED_INSTRUCTION, description="Instruction to use for embedding documents.") |
| 28 | + query_instruction: str = Field( |
| 29 | + DEFAULT_QUERY_INSTRUCTION, description="Instruction to use for embedding query.") |
| 30 | + |
| 31 | + class Config: |
| 32 | + """Configuration for this pydantic object.""" |
| 33 | + extra = Extra.forbid |
| 34 | + |
| 35 | + @root_validator(allow_reuse=True) |
| 36 | + def validate_environment(cls, values: Dict) -> Dict: |
| 37 | + """Ensure that the API key and python package exist in environment.""" |
| 38 | + values["octoai_api_token"] = get_from_dict_or_env( |
| 39 | + values, "octoai_api_token", "OCTOAI_API_TOKEN") |
| 40 | + values["endpoint_url"] = get_from_dict_or_env( |
| 41 | + values, "endpoint_url", "ENDPOINT_URL") |
| 42 | + return values |
| 43 | + |
| 44 | + @property |
| 45 | + def _identifying_params(self) -> Mapping[str, Any]: |
| 46 | + """Return the identifying parameters.""" |
| 47 | + return {"endpoint_url": self.endpoint_url, "model_kwargs": self.model_kwargs or {}} |

    def _compute_embeddings(self, texts: List[str], instruction: str) -> List[List[float]]:
        """Compute embeddings for a list of texts using an OctoAI instruct model."""
        embeddings = []
        octoai_client = client.Client(token=self.octoai_api_token)

        for text in texts:
            # The text and instruction are sent as stringified single-element lists,
            # alongside any extra model keyword arguments.
            parameter_payload = {
                "sentence": str([text]),
                "instruction": str([instruction]),
                "parameters": self.model_kwargs or {},
            }

            try:
                resp_json = octoai_client.infer(
                    self.endpoint_url, parameter_payload)
                embedding = resp_json["embeddings"]
            except Exception as e:
                raise ValueError(
                    f"Error raised by the inference endpoint: {e}") from e

            embeddings.append(embedding)

        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute document embeddings using an OctoAI instruct model."""
        texts = [text.replace("\n", " ") for text in texts]
        return self._compute_embeddings(texts, self.embed_instruction)

    def embed_query(self, text: str) -> List[float]:
        """Compute a query embedding using an OctoAI instruct model."""
        text = text.replace("\n", " ")
        return self._compute_embeddings([text], self.query_instruction)[0]
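
# A minimal usage sketch (assumptions: the endpoint URL and token below are
# placeholders, and the target endpoint returns an "embeddings" field as the
# class above expects; substitute your own values before running):
#
#     embedder = OctoAIEmbeddings(
#         endpoint_url="https://<your-endpoint>.octoai.cloud/predict",
#         octoai_api_token="my-octoai-api-token",
#     )
#     doc_vectors = embedder.embed_documents(["First document.", "Second document."])
#     query_vector = embedder.embed_query("Which document is first?")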