Skip to content

Commit 9c4a86e

Browse files
Merge pull request #13868 from kankute-sameer/litellm_feat_voyage_context_3_embedding_model
[Feat] Add support for voyage-context-3 embedding model
2 parents 6214d22 + 5ac4fb5 commit 9c4a86e

File tree

6 files changed

+575
-35
lines changed

6 files changed

+575
-35
lines changed

litellm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,7 @@ def add_known_models():
11471147
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
11481148
from .llms.groq.chat.transformation import GroqChatConfig
11491149
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
1150+
from .llms.voyage.embedding.transformation_contextual import VoyageContextualEmbeddingConfig
11501151
from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
11511152
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
11521153
from .llms.mistral.chat.transformation import MistralConfig
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""
2+
This module is used to transform the request and response for the Voyage contextualized embeddings API.
3+
This would be used for all the contextualized embeddings models in Voyage.
4+
"""
5+
from typing import List, Optional, Union
6+
7+
import httpx
8+
9+
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
10+
from litellm.llms.base_llm.chat.transformation import BaseLLMException
11+
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
12+
from litellm.secret_managers.main import get_secret_str
13+
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
14+
from litellm.types.utils import EmbeddingResponse, Usage
15+
16+
17+
class VoyageError(BaseLLMException):
    """Error raised for failures from the Voyage contextualized embeddings API.

    Stores the HTTP status code and message, and synthesizes an httpx
    request/response pair pointed at the contextualized-embeddings endpoint
    so downstream exception handlers have a concrete request context.

    Args:
        status_code: HTTP status code returned by the API.
        message: Error message text (typically the raw response body).
        headers: Optional response headers to propagate to the base exception.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        # NOTE: a None sentinel instead of a mutable `{}` default — mutable
        # defaults are shared across calls and a well-known Python pitfall.
        headers: Optional[Union[dict, httpx.Headers]] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://api.voyageai.com/v1/contextualizedembeddings"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            # Preserve the original contract: the base class receives a dict
            # (empty when no headers were supplied).
            headers=headers if headers is not None else {},
        )
35+
36+
37+
class VoyageContextualEmbeddingConfig(BaseEmbeddingConfig):
    """
    Request/response transformation for Voyage contextualized embedding models.

    Reference: https://docs.voyageai.com/reference/embeddings-api
    """

    def __init__(self) -> None:
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Resolve the endpoint URL, defaulting to the public Voyage API host."""
        if not api_base:
            return "https://api.voyageai.com/v1/contextualizedembeddings"
        suffix = "/contextualizedembeddings"
        # Append the contextualized-embeddings path unless the caller's
        # api_base already ends with it.
        return api_base if api_base.endswith(suffix) else f"{api_base}{suffix}"

    def get_supported_openai_params(self, model: str) -> list:
        """OpenAI-compatible params this provider understands."""
        return ["encoding_format", "dimensions"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Translate OpenAI param names to their Voyage equivalents.

        Reference: https://docs.voyageai.com/reference/contextualized-embeddings-api
        """
        # OpenAI name -> Voyage name; "dimensions" is the only rename.
        name_map = {
            "encoding_format": "encoding_format",
            "dimensions": "output_dimension",
        }
        for openai_name, voyage_name in name_map.items():
            if openai_name in non_default_params:
                optional_params[voyage_name] = non_default_params[openai_name]
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Build auth headers, resolving the API key from the environment if needed."""
        resolved_key = api_key
        if resolved_key is None:
            # Check the supported env var aliases in priority order; stop at
            # the first one that yields a value.
            for env_var in ("VOYAGE_API_KEY", "VOYAGE_AI_API_KEY", "VOYAGE_AI_TOKEN"):
                resolved_key = get_secret_str(env_var)
                if resolved_key:
                    break
        return {
            "Authorization": f"Bearer {resolved_key}",
        }

    def transform_embedding_request(
        self,
        model: str,
        input: Union[AllEmbeddingInputValues, List[List[str]]],
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Assemble the Voyage request body (contextualized API uses `inputs`)."""
        request_body: dict = {"inputs": input, "model": model}
        # optional_params take precedence over the base fields, matching the
        # original **-splat ordering.
        request_body.update(optional_params)
        return request_body

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        """Map the raw Voyage JSON response onto litellm's EmbeddingResponse."""
        try:
            response_json = raw_response.json()
        except Exception:
            # Non-JSON body: surface the raw text as a provider error.
            raise VoyageError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        model_response.model = response_json.get("model")
        model_response.data = response_json.get("data")
        model_response.object = response_json.get("object")

        # Voyage reports only total_tokens; mirror it into prompt_tokens.
        total = response_json.get("usage", {}).get("total_tokens", 0)
        model_response.usage = Usage(prompt_tokens=total, total_tokens=total)
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Return the provider-specific exception type for HTTP errors."""
        return VoyageError(
            message=error_message, status_code=status_code, headers=headers
        )

    @staticmethod
    def is_contextualized_embeddings(model: str) -> bool:
        """True when the model name indicates a contextualized-embedding model."""
        lowered = model.lower()
        return "context" in lowered

litellm/model_prices_and_context_window_backup.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16710,6 +16710,14 @@
1671016710
"litellm_provider": "voyage",
1671116711
"mode": "embedding"
1671216712
},
16713+
"voyage/voyage-context-3": {
16714+
"max_tokens": 120000,
16715+
"max_input_tokens": 120000,
16716+
"input_cost_per_token": 1.8e-07,
16717+
"output_cost_per_token": 0.0,
16718+
"litellm_provider": "voyage",
16719+
"mode": "embedding"
16720+
},
1671316721
"voyage/rerank-2": {
1671416722
"max_tokens": 16000,
1671516723
"max_input_tokens": 16000,

litellm/utils.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2802,12 +2802,22 @@ def _check_valid_arg(supported_params: Optional[list]):
28022802
request_type="embeddings",
28032803
)
28042804
_check_valid_arg(supported_params=supported_params)
2805-
optional_params = litellm.VoyageEmbeddingConfig().map_openai_params(
2806-
non_default_params=non_default_params,
2807-
optional_params={},
2808-
model=model,
2809-
drop_params=drop_params if drop_params is not None else False,
2810-
)
2805+
if litellm.VoyageContextualEmbeddingConfig.is_contextualized_embeddings(model):
2806+
optional_params = (
2807+
litellm.VoyageContextualEmbeddingConfig().map_openai_params(
2808+
non_default_params=non_default_params,
2809+
optional_params={},
2810+
model=model,
2811+
drop_params=drop_params if drop_params is not None else False,
2812+
)
2813+
)
2814+
else:
2815+
optional_params = litellm.VoyageEmbeddingConfig().map_openai_params(
2816+
non_default_params=non_default_params,
2817+
optional_params={},
2818+
model=model,
2819+
drop_params=drop_params if drop_params is not None else False,
2820+
)
28112821
elif custom_llm_provider == "infinity":
28122822
supported_params = get_supported_openai_params(
28132823
model=model,
@@ -7020,7 +7030,14 @@ def get_provider_embedding_config(
70207030
model: str,
70217031
provider: LlmProviders,
70227032
) -> Optional[BaseEmbeddingConfig]:
7023-
if litellm.LlmProviders.VOYAGE == provider:
7033+
if (
7034+
litellm.LlmProviders.VOYAGE == provider
7035+
and litellm.VoyageContextualEmbeddingConfig.is_contextualized_embeddings(
7036+
model
7037+
)
7038+
):
7039+
return litellm.VoyageContextualEmbeddingConfig()
7040+
elif litellm.LlmProviders.VOYAGE == provider:
70247041
return litellm.VoyageEmbeddingConfig()
70257042
elif litellm.LlmProviders.TRITON == provider:
70267043
return litellm.TritonEmbeddingConfig()

model_prices_and_context_window.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17611,6 +17611,14 @@
1761117611
"litellm_provider": "voyage",
1761217612
"mode": "embedding"
1761317613
},
17614+
"voyage/voyage-context-3": {
17615+
"max_tokens": 120000,
17616+
"max_input_tokens": 120000,
17617+
"input_cost_per_token": 1.8e-07,
17618+
"output_cost_per_token": 0.0,
17619+
"litellm_provider": "voyage",
17620+
"mode": "embedding"
17621+
},
1761417622
"voyage/rerank-2": {
1761517623
"max_tokens": 16000,
1761617624
"max_input_tokens": 16000,

0 commit comments

Comments
 (0)