Commit dcedb0a

Merge pull request #2 from octoml/octoai
add embeddings using instructor large endpoint
2 parents: c9c31ce + 31eee4e

4 files changed (+132, −22 lines)

langchain/embeddings/octoai_embeddings.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+"""Module providing a wrapper around OctoAI Compute Service embedding models."""
+
+from typing import Any, Dict, List, Mapping, Optional
+from pydantic import BaseModel, Extra, Field, root_validator
+from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
+from octoai import client
+
+DEFAULT_EMBED_INSTRUCTION = "Represent this input: "
+DEFAULT_QUERY_INSTRUCTION = "Represent the question for retrieving similar documents: "
+
+
+class OctoAIEmbeddings(BaseModel, Embeddings):
+    """
+    Wrapper around OctoAI Compute Service embedding models.
+
+    The environment variable ``OCTOAI_API_TOKEN`` should be set with your API token,
+    or it can be passed as a named parameter to the constructor.
+    """
+    endpoint_url: Optional[str] = Field(
+        None, description="Endpoint URL to use.")
+    model_kwargs: Optional[dict] = Field(
+        None, description="Keyword arguments to pass to the model.")
+    octoai_api_token: Optional[str] = Field(
+        None, description="OctoAI API Token.")
+    embed_instruction: str = Field(
+        DEFAULT_EMBED_INSTRUCTION, description="Instruction to use for embedding documents.")
+    query_instruction: str = Field(
+        DEFAULT_QUERY_INSTRUCTION, description="Instruction to use for embedding a query.")
+
+    class Config:
+        """Configuration for this pydantic object."""
+        extra = Extra.forbid
+
+    @root_validator(allow_reuse=True)
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Ensure that the API token and endpoint URL exist in the environment."""
+        values["octoai_api_token"] = get_from_dict_or_env(
+            values, "octoai_api_token", "OCTOAI_API_TOKEN")
+        values["endpoint_url"] = get_from_dict_or_env(
+            values, "endpoint_url", "ENDPOINT_URL")
+        return values
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Return the identifying parameters."""
+        return {"endpoint_url": self.endpoint_url, "model_kwargs": self.model_kwargs or {}}
+
+    def _compute_embeddings(self, texts: List[str], instruction: str) -> List[List[float]]:
+        """Compute embeddings for a list of texts using an OctoAI instruct model."""
+        embeddings = []
+        octoai_client = client.Client(token=self.octoai_api_token)
+
+        for text in texts:
+            parameter_payload = {
+                "sentence": str([text]),  # endpoint expects a stringified list
+                "instruction": str([instruction]),
+                "parameters": self.model_kwargs or {}
+            }
+
+            try:
+                resp_json = octoai_client.infer(
+                    self.endpoint_url, parameter_payload)
+                embedding = resp_json["embeddings"]
+            except Exception as e:
+                raise ValueError(
+                    f"Error raised by the inference endpoint: {e}") from e
+
+            embeddings.append(embedding)
+
+        return embeddings
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Compute document embeddings using an OctoAI instruct model."""
+        texts = list(map(lambda x: x.replace("\n", " "), texts))
+        return self._compute_embeddings(texts, self.embed_instruction)
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute a query embedding using an OctoAI instruct model."""
+        text = text.replace("\n", " ")
+        return self._compute_embeddings([text], self.query_instruction)[0]
+
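
A minimal usage sketch of the new wrapper, assuming an instructor-large embedding endpoint is deployed on OctoAI (the endpoint URL below is a placeholder; the 768-dimension expectation comes from the integration test further down):

from langchain.embeddings.octoai_embeddings import OctoAIEmbeddings

# Placeholder endpoint URL; the wrapper can also read ENDPOINT_URL and
# OCTOAI_API_TOKEN from the environment via validate_environment.
embeddings = OctoAIEmbeddings(
    endpoint_url="https://instructor-large-demo.octoai.cloud/predict",
    octoai_api_token="octoai-api-key",
)

doc_vectors = embeddings.embed_documents(["foo bar", "baz qux"])
query_vector = embeddings.embed_query("What is foo?")
print(len(doc_vectors), len(query_vector))  # 2 document vectors, one 768-dim query vector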

langchain/llms/octoai_endpoint.py

Lines changed: 21 additions & 18 deletions
@@ -21,14 +21,20 @@ class OctoAIEndpoint(LLM):
     Example:
         .. code-block:: python

-            from langchain.llms import OctoAIEndpoint
-            endpoint_url = (
-                "https://endpoint_name-account_id.octoai.cloud"
-            )
-            endpoint = OctoAIEndpoint(
-                endpoint_url=endpoint_url,
-                octoai_api_token="octoai-api-key"
+            from langchain.llms.octoai_endpoint import OctoAIEndpoint
+            OctoAIEndpoint(
+                octoai_api_token="octoai-api-key",
+                endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+                model_kwargs={
+                    "max_new_tokens": 200,
+                    "temperature": 0.75,
+                    "top_p": 0.95,
+                    "repetition_penalty": 1,
+                    "seed": None,
+                    "stop": [],
+                },
             )
+
     """

     endpoint_url: Optional[str] = None
@@ -45,7 +51,7 @@ class Config:

         extra = Extra.forbid

-    @root_validator()
+    @root_validator(allow_reuse=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         octoai_api_token = get_from_dict_or_env(
@@ -90,26 +96,23 @@ def _call(
         """
         _model_kwargs = self.model_kwargs or {}

-        # payload json
+        # Prepare the payload JSON
         parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}

-        # HTTP headers for authorization
-        headers = {
-            "Authorization": f"Bearer {self.octoai_api_token}",
-            "Content-Type": "application/json",
-        }
-
-        # send request using octaoai sdk
         try:
+            # Initialize the OctoAI client
             octoai_client = client.Client(token=self.octoai_api_token)
+
+            # Send the request using the OctoAI client
             resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
             text = resp_json["generated_text"]

         except Exception as e:
-            raise ValueError(f"Error raised by inference endpoint: {e}") from e
+            # Handle any errors raised by the inference endpoint
+            raise ValueError(f"Error raised by the inference endpoint: {e}") from e

         if stop is not None:
-            # stop tokens when making calls to octoai.
+            # Apply stop tokens when making calls to OctoAI
             text = enforce_stop_tokens(text, stop)

         return text
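
A short sketch of how the updated endpoint wrapper is typically invoked, following the docstring example above (the prompt and stop list are illustrative):

from langchain.llms.octoai_endpoint import OctoAIEndpoint

llm = OctoAIEndpoint(
    octoai_api_token="octoai-api-key",
    endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
    model_kwargs={"max_new_tokens": 200, "temperature": 0.75},
)

# _call sends {"inputs": prompt, "parameters": model_kwargs} to the endpoint,
# reads resp_json["generated_text"], and trims stop sequences client-side.
text = llm("Which state is Los Angeles in?", stop=["\n\n"])
print(text)
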
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+"""Test octoai embeddings."""
+
+from langchain.embeddings.octoai_embeddings import (
+    OctoAIEmbeddings,
+)
+
+
+def test_octoai_embedding_documents() -> None:
+    """Test octoai embeddings."""
+    documents = ["foo bar"]
+    embedding = OctoAIEmbeddings()
+    output = embedding.embed_documents(documents)
+    assert len(output) == 1
+    assert len(output[0]) == 768
+
+
+def test_octoai_embedding_query() -> None:
+    """Test octoai embeddings."""
+    document = "foo bar"
+    embedding = OctoAIEmbeddings()
+    output = embedding.embed_query(document)
+    assert isinstance(output, list)
+    assert len(output) == 768

tests/integration_tests/llms/test_octoai_endpoint.py

Lines changed: 6 additions & 4 deletions
@@ -10,12 +10,13 @@

 from tests.integration_tests.llms.utils import assert_llm_equality

+
 def test_octoai_endpoint_text_generation() -> None:
     """Test valid call to OctoAI text generation model."""
     llm = OctoAIEndpoint(
         endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
         model_kwargs={
-            "max_new_tokens": 512,
+            "max_new_tokens": 200,
             "temperature": 0.75,
             "top_p": 0.95,
             "repetition_penalty": 1,
@@ -32,8 +33,9 @@ def test_octoai_endpoint_text_generation() -> None:
 def test_octoai_endpoint_call_error() -> None:
     """Test valid call to OctoAI that errors."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
-        model_kwargs={"max_new_tokens": -1})
+        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        model_kwargs={"max_new_tokens": -1},
+    )
     with pytest.raises(ValueError):
         llm("Which state is Los Angeles in?")

@@ -43,7 +45,7 @@ def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
     llm = OctoAIEndpoint(
         endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
         model_kwargs={
-            "max_new_tokens": 512,
+            "max_new_tokens": 200,
             "temperature": 0.75,
             "top_p": 0.95,
             "repetition_penalty": 1,
