Commit ae1c8ec

feat: Integrate any-llm-platform for key management and usage tracking (#618)
Integrate any-llm with the any-llm-platform. Key changes include:

- Introduce a mechanism to securely fetch provider keys from the any-llm-platform using a challenge-response authentication method.
- Send completion usage data (e.g., token counts) to the platform.
- Add unit tests and fix failing ones.

Signed-off-by: Dimitris Poulopoulos <dimitris@mozilla.ai>
1 parent 6fd3dac commit ae1c8ec
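The commit message mentions challenge-response authentication for fetching provider keys. As a rough, stdlib-only illustration of that general idea (not the any-llm-platform wire protocol; the secret, nonce size, and HMAC-SHA256 choice are all assumptions for the sketch):

```python
# Generic challenge-response sketch. Illustrative only: the real protocol
# lives in any-llm's platform utilities and is not shown in this diff.
import hashlib
import hmac
import os

shared_secret = b"platform-issued secret"  # stands in for the ANY_LLM_KEY material

# Server side: issue a random, single-use nonce.
challenge = os.urandom(32)

# Client side: prove possession of the secret without sending it.
response = hmac.new(shared_secret, challenge, hashlib.sha256).digest()

# Server side: recompute and compare in constant time before releasing
# the provider key to the client.
expected = hmac.new(shared_secret, challenge, hashlib.sha256).digest()
assert hmac.compare_digest(response, expected)
```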

File tree

9 files changed: +1094, -4 lines


pyproject.toml

Lines changed: 8 additions & 2 deletions
```diff
@@ -13,13 +13,19 @@ dependencies = [
     "openai>=1.99.3",
     "rich",
     "httpx",
-    "typing-extensions"
 ]
 
 [project.optional-dependencies]
 
 all = [
-    "any-llm-sdk[mistral,anthropic,huggingface,gemini,vertexai,cohere,cerebras,fireworks,groq,bedrock,azure,azureopenai,watsonx,together,sambanova,ollama,moonshot,nebius,xai,databricks,deepseek,inception,openai,openrouter,portkey,lmstudio,llama,voyage,perplexity,llamafile,llamacpp,sagemaker,gateway,zai,minimax]"
+    "any-llm-sdk[mistral,anthropic,huggingface,gemini,vertexai,cohere,cerebras,fireworks,groq,bedrock,azure,azureopenai,watsonx,together,sambanova,ollama,moonshot,nebius,xai,databricks,deepseek,inception,openai,openrouter,portkey,lmstudio,llama,voyage,perplexity,platform,llamafile,llamacpp,sagemaker,gateway,zai,minimax]"
+]
+
+platform = [
+    "typing-extensions",
+    "bcrypt>=5.0.0",
+    "cryptography>=46.0.3",
+    "pynacl>=1.6.0",
 ]
 
 perplexity = []
```
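The new `platform` extra carries the crypto dependencies (bcrypt, cryptography, pynacl) used for the key exchange. A hedged sketch of the usual guard for optional extras (any-llm's actual check is `_verify_no_missing_packages`, called in `AnyLLM.__init__` below; the error message here is illustrative):

```python
# Fail fast with an actionable message when the `platform` extra is missing.
try:
    import bcrypt  # noqa: F401
    import nacl  # noqa: F401  # pynacl installs as the `nacl` module
    from cryptography.hazmat.primitives import hashes  # noqa: F401
except ImportError as e:
    msg = "Platform support needs extra dependencies: pip install 'any-llm-sdk[platform]'"
    raise ImportError(msg) from e
```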

src/any_llm/any_llm.py

Lines changed: 33 additions & 1 deletion
```diff
@@ -11,7 +11,7 @@
 from any_llm.exceptions import MissingApiKeyError, UnsupportedProviderError
 from any_llm.tools import prepare_tools
 from any_llm.types.completion import ChatCompletion, ChatCompletionMessage, CompletionParams
-from any_llm.types.provider import ProviderMetadata
+from any_llm.types.provider import PlatformKey, ProviderMetadata
 from any_llm.types.responses import Response, ResponseInputParam, ResponsesParams, ResponseStreamEvent
 from any_llm.utils.aio import async_iter_to_sync_iter, run_async_in_sync
 from any_llm.utils.decorators import BATCH_API_EXPERIMENTAL_MESSAGE, experimental
@@ -87,6 +87,8 @@ class AnyLLM(ABC):
     For example, in `gemini` provider, this could include `google.genai.types.Tool`.
     """
 
+    ANY_LLM_KEY: str = "ANY_LLM_KEY"
+
     def __init__(self, api_key: str | None = None, api_base: str | None = None, **kwargs: Any) -> None:
         self._verify_no_missing_packages()
         self._init_client(
@@ -147,6 +149,36 @@ def _create_provider(
             raise ImportError(msg) from e
 
         provider_class: type[AnyLLM] = getattr(module, provider_class_name)
+
+        if not api_key:
+            api_key = os.getenv(cls.ANY_LLM_KEY)
+
+        if api_key:
+            try:
+                # Validate whether the key conforms to the any-api format.
+                # If it does, any-llm must ask any-api for the corresponding provider key.
+                PlatformKey(api_key=api_key)
+
+                # Import and instantiate PlatformProvider in place to avoid circular dependency issues.
+                platform_class_name = "PlatformProvider"
+                platform_module_path = "any_llm.providers.platform"
+                try:
+                    platform_module = importlib.import_module(platform_module_path)
+                except ImportError as e:
+                    msg = f"Could not import module {platform_module_path}: {e!s}. Please ensure the provider is supported by doing AnyLLM.get_supported_providers()"
+                    raise ImportError(msg) from e
+
+                platform_class: type[AnyLLM] = getattr(platform_module, platform_class_name)
+
+                # Instantiate the class first and set the provider next,
+                # so we don't change the common API between different providers.
+                platform_provider = platform_class(api_key=api_key, api_base=api_base, **kwargs)
+                platform_provider.provider = provider_class  # type: ignore[attr-defined]
+            except ValueError:
+                pass
+            else:
+                return platform_provider
+
         return provider_class(api_key=api_key, api_base=api_base, **kwargs)
 
     @classmethod
```
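The control flow of the new block is: parse the key as a `PlatformKey`; a `ValueError` means "not a platform key", so fall through to the plain provider; otherwise reroute. A self-contained toy of that try/except/else dispatch (the `ANYLLM-` prefix and both return strings are invented stand-ins, not the real key format or API):

```python
class PlatformKey:
    """Stand-in validator; the real class is any_llm.types.provider.PlatformKey."""

    def __init__(self, api_key: str) -> None:
        if not api_key.startswith("ANYLLM-"):  # assumed format, for illustration only
            raise ValueError("not a platform key")
        self.api_key = api_key


def resolve(api_key: str) -> str:
    try:
        PlatformKey(api_key=api_key)
    except ValueError:
        # A plain provider key: hand it to the provider unchanged.
        return "provider"
    else:
        # A platform key: route through PlatformProvider, which exchanges
        # it for the real provider key.
        return "platform"


assert resolve("sk-plain-provider-key") == "provider"
assert resolve("ANYLLM-issued-key") == "platform"
```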

src/any_llm/constants.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -39,6 +39,7 @@ class LLMProvider(StrEnum):
     OLLAMA = "ollama"
     OPENAI = "openai"
     OPENROUTER = "openrouter"
+    PLATFORM = "platform"
     PORTKEY = "portkey"
     SAMBANOVA = "sambanova"
     SAGEMAKER = "sagemaker"
```
src/any_llm/providers/platform/__init__.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -0,0 +1,4 @@
+from .platform import PlatformProvider
+from .utils import post_completion_usage_event
+
+__all__ = ["PlatformProvider", "post_completion_usage_event"]
```
src/any_llm/providers/platform/platform.py

Lines changed: 158 additions & 0 deletions

```diff
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, cast
+
+from httpx import AsyncClient
+
+from any_llm.any_llm import AnyLLM
+from any_llm.logging import logger
+from any_llm.types.completion import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    CompletionParams,
+    CompletionUsage,
+    CreateEmbeddingResponse,
+)
+
+from .utils import get_provider_key, post_completion_usage_event
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator, Sequence
+
+    from any_llm.types.model import Model
+
+
+class PlatformProvider(AnyLLM):
+    PROVIDER_NAME = "platform"
+    ENV_API_KEY_NAME = "ANY_LLM_KEY"
+    PROVIDER_DOCUMENTATION_URL = "https://github.com/mozilla-ai/any-llm"
+
+    # All features are marked as supported, but depending on which provider you route through the platform, some may not work.
+    SUPPORTS_COMPLETION_STREAMING = True
+    SUPPORTS_COMPLETION = True
+    SUPPORTS_RESPONSES = True
+    SUPPORTS_COMPLETION_REASONING = True
+    SUPPORTS_COMPLETION_IMAGE = True
+    SUPPORTS_COMPLETION_PDF = True
+    SUPPORTS_EMBEDDING = True
+    SUPPORTS_LIST_MODELS = True
+    SUPPORTS_BATCH = True
+
+    def __init__(self, api_key: str | None = None, api_base: str | None = None, **kwargs: Any):
+        self.any_llm_key = self._verify_and_set_api_key(api_key)
+        self.api_base = api_base
+        self.kwargs = kwargs
+
+        self._init_client(api_key=api_key, api_base=api_base, **kwargs)
+
+    def _init_client(self, api_key: str | None = None, api_base: str | None = None, **kwargs: Any) -> None:
+        self.client = AsyncClient(**kwargs)
+
+    @staticmethod
+    def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]:
+        raise NotImplementedError
+
+    @staticmethod
+    def _convert_completion_response(response: Any) -> ChatCompletion:
+        raise NotImplementedError
+
+    @staticmethod
+    def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk:
+        raise NotImplementedError
+
+    @staticmethod
+    def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]:
+        raise NotImplementedError
+
+    @staticmethod
+    def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse:
+        raise NotImplementedError
+
+    @staticmethod
+    def _convert_list_models_response(response: Any) -> Sequence[Model]:
+        raise NotImplementedError
+
+    async def _acompletion(
+        self,
+        params: CompletionParams,
+        **kwargs: Any,
+    ) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]:
+        completion = await self.provider._acompletion(params=params, **kwargs)
+
+        if not params.stream:
+            await post_completion_usage_event(
+                client=self.client,
+                any_llm_key=self.any_llm_key,  # type: ignore[arg-type]
+                provider=self.provider.PROVIDER_NAME,
+                completion=cast("ChatCompletion", completion),
+            )
+            return completion
+
+        # For streaming, wrap the iterator to collect usage info.
+        return self._stream_with_usage_tracking(cast("AsyncIterator[ChatCompletionChunk]", completion))
+
+    async def _stream_with_usage_tracking(
+        self, stream: AsyncIterator[ChatCompletionChunk]
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        """Wrap the stream to track usage after completion."""
+        chunks: list[ChatCompletionChunk] = []
+
+        async for chunk in stream:
+            chunks.append(chunk)
+            yield chunk
+
+        # After the stream completes, reconstruct a completion for usage tracking.
+        if chunks:
+            # Combine chunks into a single ChatCompletion-like object.
+            final_completion = self._combine_chunks(chunks)
+            await post_completion_usage_event(
+                client=self.client,
+                any_llm_key=self.any_llm_key,  # type: ignore[arg-type]
+                provider=self.provider.PROVIDER_NAME,
+                completion=final_completion,
+            )
+
+    def _combine_chunks(self, chunks: list[ChatCompletionChunk]) -> ChatCompletion:
+        """Combine streaming chunks into a ChatCompletion for usage tracking."""
+        # The last chunk typically carries the full usage info.
+        last_chunk = chunks[-1]
+
+        if not last_chunk.usage:
+            msg = (
+                "The last chunk of your streaming response does not contain usage data. "
+                "Consult your provider documentation on how to retrieve it."
+            )
+            logger.error(msg)
+
+            return ChatCompletion(
+                id=last_chunk.id,
+                model=last_chunk.model,
+                created=last_chunk.created,
+                object="chat.completion",
+                usage=CompletionUsage(
+                    completion_tokens=0,
+                    prompt_tokens=0,
+                    total_tokens=0,
+                ),
+                choices=[],
+            )
+
+        # Create a minimal ChatCompletion with only the data needed for usage
+        # tracking: id, model, created, usage, and object type. The check above
+        # guarantees last_chunk.usage is set here.
+        return ChatCompletion(
+            id=last_chunk.id,
+            model=last_chunk.model,
+            created=last_chunk.created,
+            object="chat.completion",
+            usage=last_chunk.usage,
+            choices=[],
+        )
+
+    @property
+    def provider(self) -> AnyLLM:
+        return self._provider
+
+    @provider.setter
+    def provider(self, provider_class: type[AnyLLM]) -> None:
+        provider_key = get_provider_key(any_llm_key=self.any_llm_key, provider=provider_class)  # type: ignore[arg-type]
+        self._provider = provider_class(api_key=provider_key, api_base=self.api_base, **self.kwargs)
```
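The streaming path above yields each chunk to the caller as it arrives and fires the usage event only once the iterator is exhausted (the diff buffers all chunks, though only the final one's usage is read). A runnable sketch of the same wrapper pattern, with plain dicts standing in for `ChatCompletionChunk` and `print` standing in for `post_completion_usage_event`:

```python
import asyncio
from collections.abc import AsyncIterator


async def fake_stream() -> AsyncIterator[dict]:
    # Providers typically send usage only on the final chunk.
    for i in range(3):
        yield {"id": "c1", "delta": f"tok{i} ", "usage": None}
    yield {"id": "c1", "delta": "", "usage": {"total_tokens": 3}}


async def with_usage_tracking(stream: AsyncIterator[dict]) -> AsyncIterator[dict]:
    last = None
    async for chunk in stream:
        last = chunk
        yield chunk  # pass chunks through with no added latency
    if last is not None:
        # Stream exhausted: report usage taken from the final chunk.
        print("\nusage event:", last["usage"])


async def main() -> None:
    async for chunk in with_usage_tracking(fake_stream()):
        print(chunk["delta"], end="")


asyncio.run(main())
```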
