Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ pydantic>=1.10.0
playwright>=1.42.1
requests>=2.31.0
rich
browserbase
browserbase
litellm==1.67.1
8 changes: 8 additions & 0 deletions stagehand/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .base import StagehandBase
from .config import StagehandConfig
from .context import StagehandContext
from .llm import LLMClient
from .page import StagehandPage
from .utils import StagehandLogger, convert_dict_keys_to_camel_case, default_log_handler

Expand Down Expand Up @@ -101,6 +102,13 @@ def __init__(
verbose=self.verbose, external_logger=on_log, use_rich=use_rich_logging
)

# Instantiate the LLM client
self.llm = LLMClient(
api_key=self.model_api_key,
default_model=self.model_name,
**self.model_client_options,
)

self.httpx_client = httpx_client
self.timeout_settings = timeout_settings or httpx.Timeout(
connect=180.0,
Expand Down
1 change: 1 addition & 0 deletions stagehand/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .client import LLMClient
110 changes: 110 additions & 0 deletions stagehand/llm/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import logging
from typing import Any, Optional

import litellm

# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)


class LLMClient:
    """
    Client for making LLM calls using the litellm library.
    Provides a simplified interface for chat completions; provider routing
    and authentication details are delegated entirely to litellm.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        default_model: Optional[str] = None,
        **kwargs: Any,  # To catch other potential litellm global settings
    ):
        """
        Initializes the LLMClient.

        Args:
            api_key: An API key for the default provider, if required.
                     It's often better to set provider-specific environment variables
                     (e.g., OPENAI_API_KEY, ANTHROPIC_API_KEY) which litellm reads automatically.
                     Passing api_key here sets litellm.api_key globally, which may
                     not be desired if using multiple providers.
            default_model: The default model to use if none is specified in create_response
                           (e.g., "gpt-4o", "claude-3-opus-20240229").
            **kwargs: Additional global settings for litellm (e.g., api_base).
                      See litellm documentation for available settings.
        """
        self.default_model = default_model

        # Warning: this mutates *global* litellm state. Prefer environment
        # variables for specific providers when several are in play.
        if api_key:
            litellm.api_key = api_key
            logger.warning(
                "Set global litellm.api_key. Prefer provider-specific environment variables."
            )

        # Apply other global settings if provided
        for key, value in kwargs.items():
            if hasattr(litellm, key):
                setattr(litellm, key, value)
                logger.debug("Set global litellm.%s", key)
            # Handle common aliases or expected config names if necessary
            elif key == "api_base":  # Example: map api_base if needed
                litellm.api_base = value
                logger.debug("Set global litellm.api_base to %s", value)
            else:
                # Previously unknown keys were dropped silently; log them so
                # misspelled options are discoverable.
                logger.warning("Ignoring unknown litellm setting: %s", key)

    def create_response(
        self,
        *,
        messages: list[dict[str, str]],
        model: Optional[str] = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Generate a chat completion response using litellm.

        Args:
            messages: A list of message dictionaries, e.g., [{"role": "user", "content": "Hello"}].
            model: The specific model to use (e.g., "gpt-4o", "claude-3-opus-20240229").
                   Overrides the default_model if provided.
            **kwargs: Additional parameters to pass directly to litellm.completion
                      (e.g., temperature, max_tokens, stream=True, specific provider arguments).

        Returns:
            A dictionary containing the completion response from litellm, typically
            including choices, usage statistics, etc. Structure depends on the model
            provider and whether streaming is used.

        Raises:
            ValueError: If no model is specified (neither default nor in the call).
            Exception: Propagates exceptions from litellm.completion.
        """
        completion_model = model or self.default_model
        if not completion_model:
            raise ValueError(
                "No model specified for chat completion (neither default_model nor model argument)."
            )

        # Pass caller kwargs straight through. (The former None-filtering
        # comprehension was a no-op: `completion_model` is guaranteed truthy
        # above, `messages` is required, and every kwargs key was explicitly
        # kept even when its value was None.)
        params: dict[str, Any] = {
            "model": completion_model,
            "messages": messages,
            **kwargs,  # Pass through any extra arguments
        }

        # Lazy %-formatting avoids building the message when DEBUG is off.
        logger.debug(
            "Calling litellm.completion with model=%s and params: %s",
            completion_model,
            params,
        )
        try:
            # Use litellm's completion function
            response = litellm.completion(**params)
            return response

        except Exception as e:
            logger.error("Error calling litellm.completion: %s", e, exc_info=True)
            # Re-raise so callers can apply provider-specific handling.
            raise
8 changes: 7 additions & 1 deletion stagehand/sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ..base import StagehandBase
from ..config import StagehandConfig
from ..llm.client import LLMClient
from ..utils import StagehandLogger, convert_dict_keys_to_camel_case, sync_log_handler
from .agent import SyncAgent
from .context import SyncStagehandContext
Expand Down Expand Up @@ -74,7 +75,12 @@ def __init__(
# self.context: Optional[SyncStagehandContext] = None
self.agent = None
self.model_client_options = model_client_options
self.streamed_response = True # Default to True for streamed responses
self.streamed_response = True # Default to True for streamed response
self.llm = LLMClient(
api_key=self.model_api_key,
default_model=self.model_name,
**self.model_client_options,
)

def init(self):
"""
Expand Down