Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ A unified Python library for inference across multiple Large Language Model prov

## Features

- **Multi-provider support**: Unified API for OpenAI, AWS Bedrock (Claude, Llama, Cohere, Qwen), and Azure OpenAI
- **Multi-provider support**: Unified API for OpenAI, AWS Bedrock (Claude, Llama, Cohere, Qwen), Azure OpenAI, and local OpenAI-compatible endpoints (vLLM, Ollama, LM Studio)
- **Response caching**: DuckDB-based caching to avoid redundant API calls and reduce costs
- **Batch processing**: Async batch inference with configurable concurrency
- **Structured output**: Pydantic model validation for enforcing response schemas
Expand Down Expand Up @@ -111,10 +111,49 @@ load_dotenv()
- `Cohere.COMMAND_R` - Command R
- `Qwen.QWEN_3_235` - Qwen 3 235B

### Local OpenAI-Compatible Endpoints
- `LocalOpenAi.LOCAL` - Any model served via an OpenAI-compatible API (vLLM, Ollama, LM Studio, etc.)

For the complete list, see [lab_llm/constants.py](lab_llm/constants.py).

## Usage Examples

### Using a Local Model (vLLM, Ollama, etc.)

```python
from lab_llm.constants import LocalOpenAi, LLMModel
from lab_llm.llm_api import LLMApi

model = LLMModel(name=LocalOpenAi.LOCAL)
api = LLMApi(
cache=cache,
seed=42,
model_type=model,
error_handler=error_handler,
logging=logger,
base_url="http://localhost:8000/v1", # Your local server URL
local_model_name="Qwen/Qwen3-32B", # Model name on the server
)

response = api.get_output("What is the capital of France?")
```

If your local server does not support structured output (e.g., no guided decoding), disable it:

```python
api = LLMApi(
...,
native_structured_output=False, # Falls back to manual JSON parsing
)
```

You can also configure the endpoint via environment variables instead of constructor args:

```bash
# .env
LOCAL_LLM_BASE_URL=http://localhost:8000/v1
LOCAL_LLM_API_KEY=not-needed # optional, defaults to "not-needed"
```

### Using Structured Output (Pydantic)

```python
Expand Down
4 changes: 4 additions & 0 deletions lab_llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
Meta,
Cohere,
Qwen,
LocalOpenAi,
REASONING_MODELS,
is_reasoning_model,
is_local_openai,
list_available_models,
parse_model_string,
)
Expand All @@ -56,9 +58,11 @@
"Meta",
"Cohere",
"Qwen",
"LocalOpenAi",
# Reasoning model utilities
"REASONING_MODELS",
"is_reasoning_model",
"is_local_openai",
# Model helpers
"list_available_models",
"parse_model_string",
Expand Down
11 changes: 9 additions & 2 deletions lab_llm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,14 @@ class Qwen(str, Enum):
QWEN_3_235 = "qwen.qwen3-235b-a22b-2507-v1:0"


class LocalOpenAi(str, Enum):
    """Sentinel enum for models served via a local OpenAI-compatible endpoint.

    The single ``LOCAL`` member is a placeholder selector, not a real model
    id: the actual model name served by the endpoint (vLLM, Ollama, LM
    Studio, ...) is supplied separately via the ``local_model_name``
    constructor argument, and the endpoint URL via ``base_url`` or the
    ``LOCAL_LLM_BASE_URL`` environment variable.
    """

    # Placeholder value; the "local" string is never sent to the server as a
    # model name (the client falls back to local_model_name or "default").
    LOCAL = "local"


BedrockModels = Union[Cohere, Claude, Meta, Qwen]
Models = Union[VersaOpenAi, BedrockModels, OpenAi]
Models = Union[VersaOpenAi, BedrockModels, OpenAi, LocalOpenAi]


class LLMModel(BaseModel):
Expand All @@ -77,6 +83,7 @@ class LLMModel(BaseModel):
# Provider-family predicates for model enum members.
# Plain ``def`` instead of assigned lambdas (PEP 8 / flake8 E731): identical
# call interface, but with useful __name__/__qualname__ and clearer tracebacks.
def is_meta(x):
    """Return True if *x* is a Meta (Llama) model enum member."""
    return isinstance(x, Meta)


def is_versa(x):
    """Return True if *x* is a VersaOpenAi model enum member."""
    return isinstance(x, VersaOpenAi)


def is_openai(x):
    """Return True if *x* is an OpenAi model enum member."""
    return isinstance(x, OpenAi)


def is_local_openai(x):
    """Return True if *x* is the LocalOpenAi local-endpoint sentinel."""
    return isinstance(x, LocalOpenAi)

# Reasoning models support reasoning_effort and verbosity parameters
REASONING_MODELS = {
Expand All @@ -92,7 +99,7 @@ class LLMModel(BaseModel):

def get_all_model_enums():
    """Return the list of all model enum classes known to the library."""
    return [
        VersaOpenAi,
        OpenAi,
        Cohere,
        Claude,
        Meta,
        Qwen,
        LocalOpenAi,
    ]


def list_available_models() -> list:
Expand Down
39 changes: 31 additions & 8 deletions lab_llm/llm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def __init__(
verbosity: str = "medium",
timeout: int = 60,
return_exceptions: bool = False,
base_url: str = None,
local_model_name: str = None,
native_structured_output: bool = True,
):
super().__init__(seed, model_type, logging)
self.cache = cache
Expand All @@ -60,13 +63,23 @@ def __init__(
self.return_exceptions = return_exceptions
self.reasoning_effort = reasoning_effort
self.verbosity = verbosity
self.base_url = base_url
self.local_model_name = local_model_name
self.native_structured_output = native_structured_output

def _get_cache_reasoning_params(self):
    """Reasoning params for the cache key; both None for non-reasoning models.

    Only reasoning models fold reasoning_effort/verbosity into the cache
    key, so requests to other models hit the cache regardless of how those
    (ignored) settings are configured.
    """
    uses_reasoning = constants.is_reasoning_model(self.model_type.name)
    return {
        "reasoning_effort": self.reasoning_effort if uses_reasoning else None,
        "verbosity": self.verbosity if uses_reasoning else None,
    }

def _supports_native_structured_output(self):
    """Return True if the provider's native structured output should be used.

    Local OpenAI-compatible endpoints defer to the user-supplied
    ``native_structured_output`` flag (some local servers lack guided
    decoding); Meta (Llama) models never support it; every other provider
    is treated as supporting it.
    """
    model_name = self.model_type.name
    if constants.is_local_openai(model_name):
        # User decides: servers without guided decoding need the manual
        # JSON-parsing fallback instead.
        return self.native_structured_output
    # A member cannot be both LocalOpenAi and Meta, so checking Meta second
    # is equivalent to the original ordering.
    return not constants.is_meta(model_name)

def _serialize_llm_response(self, llm_response, response_model: BaseModel = None):
try:
if response_model is not None:
Expand Down Expand Up @@ -135,6 +148,20 @@ def get_client(self, max_new_tokens=4000, temperature=0, requests_per_second=Non
#kwargs["verbosity"] = self.verbosity
kwargs["model_kwargs"] = {"verbosity": self.verbosity}
return AzureChatOpenAI(**kwargs)
elif constants.is_local_openai(self.model_type.name):
return ChatOpenAI(
api_key=os.getenv("LOCAL_LLM_API_KEY", "not-needed"),
base_url=self.base_url or os.getenv(
"LOCAL_LLM_BASE_URL", "http://localhost:8000/v1"
),
model=self.local_model_name or "default",
max_tokens=max_new_tokens,
temperature=temperature,
seed=self.seed,
timeout=self.timeout,
rate_limiter=rate_limiter,
callbacks=[self.error_handler],
)
elif constants.is_bedrock(self.model_type.name):
access_key = os.getenv("BEDROCK_ACCESS_KEY")
secret_access_key = os.getenv("BEDROCK_ACCESS_KEY_SECRET")
Expand Down Expand Up @@ -192,16 +219,14 @@ def get_output(
else: # Not found in cache
self.logging.info("Cache miss")
llm = self.get_client(max_new_tokens, temperature)
if (response_model is not None) and (
not constants.is_meta(self.model_type.name)
):
if (response_model is not None) and self._supports_native_structured_output():
llm = llm.with_structured_output(response_model)
messages = [
SystemMessage(content="You are a helpful assistant"),
HumanMessage(content=prompt),
]
llm_response = llm.invoke(messages)
if (response_model is not None) and constants.is_meta(self.model_type.name):
if (response_model is not None) and not self._supports_native_structured_output():
llm_response = response_model.model_validate(
from_json(llm_response.content)
)
Expand Down Expand Up @@ -289,9 +314,7 @@ async def get_outputs(
llm = self.get_client(
max_new_tokens, temperature, requests_per_second
)
if (response_model is not None) and (
not constants.is_meta(self.model_type.name)
):
if (response_model is not None) and self._supports_native_structured_output():
llm = llm.with_structured_output(response_model)

batch_results = await self._run_batch(
Expand All @@ -301,7 +324,7 @@ async def get_outputs(
temperature,
response_model=(
response_model
if not constants.is_meta(self.model_type.name)
if self._supports_native_structured_output()
else None
),
prompt_cache_key=prompt_cache_key,
Expand Down
Loading