Add baseten integration #389
@@ -0,0 +1,185 @@
"""Baseten model provider. | ||
- Docs: https://docs.baseten.co/ | ||
""" | ||
|
||
import logging | ||
from typing import Any, Generator, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast | ||
|
||
import openai | ||
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion | ||
from pydantic import BaseModel | ||
from typing_extensions import Unpack, override | ||
|
||
from ..types.content import Messages | ||
from ..types.models import OpenAIModel | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
T = TypeVar("T", bound=BaseModel) | ||
|
||
|
||
class Client(Protocol):
    """Protocol defining the OpenAI-compatible interface for the underlying provider client."""

    @property
    def chat(self) -> Any:
        """Chat completions interface."""
        ...  # pragma: no cover
class BasetenModel(OpenAIModel):
    """Baseten model provider implementation."""

    client: Client

    class BasetenConfig(TypedDict, total=False):
        """Configuration options for Baseten models.

        Attributes:
            model_id: Model ID for the Baseten model.
                For Model APIs, use model slugs like "deepseek-ai/DeepSeek-R1-0528" or
                "meta-llama/Llama-4-Maverick-17B-128E-Instruct".
                For dedicated deployments, use the deployment ID.
            base_url: Base URL for the Baseten API.
                For Model APIs: https://inference.baseten.co/v1
                For dedicated deployments: https://model-xxxxxxx.api.baseten.co/environments/production/sync/v1
            params: Model parameters (e.g., max_tokens).
                For a complete list of supported parameters, see
                https://platform.openai.com/docs/api-reference/chat/create.
        """

        model_id: str
        base_url: Optional[str]
        params: Optional[dict[str, Any]]

    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[BasetenConfig]) -> None:
        """Initialize provider instance.

        Args:
            client_args: Arguments for the Baseten client.
                For a complete list of supported arguments, see https://pypi.org/project/openai/.
            **model_config: Configuration options for the Baseten model.
        """
        self.config = dict(model_config)

        logger.debug("config=<%s> | initializing", self.config)

        client_args = client_args or {}

        # Set default base URL for Model APIs if not provided
        if "base_url" not in client_args and "base_url" not in self.config:
            client_args["base_url"] = "https://inference.baseten.co/v1"
        elif "base_url" in self.config:
            client_args["base_url"] = self.config["base_url"]

        self.client = openai.OpenAI(**client_args)
Reviewer comment: We've migrated to AsyncOpenAI in our implementation. Please verify this change is properly reflected throughout the codebase in your PR. Also, ensure you've pulled the most recent code before proceeding with your review.
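For context, a minimal sketch of what that migration might look like at this call site, assuming the rest of the class stays as written; openai.AsyncOpenAI is the async client class shipped by the openai package, but how it would be wired in here is an assumption, not code from this PR:

        # Hypothetical async variant of the client setup (sketch only):
        self.client = openai.AsyncOpenAI(**client_args)  # downstream calls must then be awaited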
    @override
    def update_config(self, **model_config: Unpack[BasetenConfig]) -> None:  # type: ignore[override]
        """Update the Baseten model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> BasetenConfig:
        """Get the Baseten model configuration.

        Returns:
            The Baseten model configuration.
        """
        return cast(BasetenModel.BasetenConfig, self.config)
    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
        """Send the request to the Baseten model and get the streaming response.

        Args:
            request: The formatted request to send to the Baseten model.

        Returns:
            An iterable of response events from the Baseten model.
        """
        response = self.client.chat.completions.create(**request)

Reviewer comment (on the stream definition above): Also here.

Reviewer comment (on the create call): async happens also here ^^
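If the provider does move to AsyncOpenAI as requested, this streaming path would roughly become an async generator. A sketch under that assumption (not code from this PR; AsyncIterable comes from typing, and the names mirror the method above):

        # Hypothetical async form of the streaming entry point (sketch only):
        async def stream(self, request: dict[str, Any]) -> AsyncIterable[dict[str, Any]]:
            response = await self.client.chat.completions.create(**request)  # await the async client
            yield {"chunk_type": "message_start"}
            async for event in response:  # the async client streams via async iteration
                ...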
        yield {"chunk_type": "message_start"}
        yield {"chunk_type": "content_start", "data_type": "text"}

        tool_calls: dict[int, list[Any]] = {}

        for event in response:
            # Defensive: skip events with empty or missing choices
            if not getattr(event, "choices", None):
                continue
            choice = event.choices[0]

            if choice.delta.content:
                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}

            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
                yield {
                    "chunk_type": "content_delta",
                    "data_type": "reasoning_content",
                    "data": choice.delta.reasoning_content,
                }

            for tool_call in choice.delta.tool_calls or []:
                tool_calls.setdefault(tool_call.index, []).append(tool_call)

            if choice.finish_reason:
                break

        yield {"chunk_type": "content_stop", "data_type": "text"}

        for tool_deltas in tool_calls.values():
            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}

            for tool_delta in tool_deltas:
                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}

            yield {"chunk_type": "content_stop", "data_type": "tool"}

        yield {"chunk_type": "message_stop", "data": choice.finish_reason}
        # Drain the remaining events; only the final usage payload is needed
        for event in response:
            _ = event

        yield {"chunk_type": "metadata", "data": event.usage}
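To make the chunk protocol concrete: for a plain text completion with no tool calls, the sequence yielded by stream above would look roughly like the following (illustrative values; the finish reason and usage object come from the provider response):

    {"chunk_type": "message_start"}
    {"chunk_type": "content_start", "data_type": "text"}
    {"chunk_type": "content_delta", "data_type": "text", "data": "Hello"}
    {"chunk_type": "content_stop", "data_type": "text"}
    {"chunk_type": "message_stop", "data": "stop"}
    {"chunk_type": "metadata", "data": CompletionUsage(...)}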
Reviewer comment (on structured_output): you might want to update async here

    @override
    def structured_output(
        self, output_model: Type[T], prompt: Messages
    ) -> Generator[dict[str, Union[T, Any]], None, None]:
        """Get structured output from the model.

        Args:
            output_model: The output model to use for the agent.
            prompt: The prompt messages to use for the agent.

        Yields:
            Model events with the last being the structured output.
        """
        response: ParsedChatCompletion = self.client.beta.chat.completions.parse(  # type: ignore
            model=self.get_config()["model_id"],
            messages=super().format_request(prompt)["messages"],
            response_format=output_model,
        )
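Under the same AsyncOpenAI assumption flagged above, this call would simply be awaited inside an async generator; a sketch:

        # Hypothetical async form (sketch; assumes self.client is openai.AsyncOpenAI):
        response: ParsedChatCompletion = await self.client.beta.chat.completions.parse(
            model=self.get_config()["model_id"],
            messages=super().format_request(prompt)["messages"],
            response_format=output_model,
        )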
|
||
parsed: T | None = None | ||
# Find the first choice with tool_calls | ||
if len(response.choices) > 1: | ||
raise ValueError("Multiple choices found in the Baseten response.") | ||
|
||
for choice in response.choices: | ||
if isinstance(choice.message.parsed, output_model): | ||
parsed = choice.message.parsed | ||
break | ||
|
||
if parsed: | ||
yield {"output": parsed} | ||
else: | ||
raise ValueError("No valid tool use or tool use input was found in the Baseten response.") |
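For reference, the two configuration modes described in BasetenConfig would be instantiated roughly as follows; this is a sketch based on the integration tests below, and the dedicated-deployment URL is a placeholder:

    # Model APIs: pass a hosted model slug; the default base URL is used.
    model = BasetenModel(
        model_id="deepseek-ai/DeepSeek-R1-0528",
        client_args={"api_key": os.environ["BASETEN_API_KEY"]},
    )

    # Dedicated deployment: pass the per-deployment base URL (placeholder shown).
    model = BasetenModel(
        base_url="https://model-xxxxxxx.api.baseten.co/environments/production/sync/v1",
        client_args={"api_key": os.environ["BASETEN_API_KEY"]},
    )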
@@ -0,0 +1,175 @@
import os

import pytest
from pydantic import BaseModel

import strands
from strands import Agent
from strands.models.baseten import BasetenModel


@pytest.fixture
def model_model_apis():
    """Test with Model APIs using the DeepSeek V3 model."""
    return BasetenModel(
        model_id="deepseek-ai/DeepSeek-V3-0324",
        client_args={
            "api_key": os.getenv("BASETEN_API_KEY"),
        },
    )


@pytest.fixture
def model_dedicated_deployment():
    """Test with a dedicated deployment -- change this to your deployment URL when testing."""
    base_url = "https://model-232k7g23.api.baseten.co/environments/production/sync/v1"

    return BasetenModel(
        base_url=base_url,
        client_args={
            "api_key": os.getenv("BASETEN_API_KEY"),
        },
    )


@pytest.fixture
def tools():
    @strands.tool
    def tool_time() -> str:
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        return "sunny"

    return [tool_time, tool_weather]


@pytest.fixture
def agent_model_apis(model_model_apis, tools):
    return Agent(model=model_model_apis, tools=tools)


@pytest.fixture
def agent_dedicated(model_dedicated_deployment, tools):
    return Agent(model=model_dedicated_deployment, tools=tools)
@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_agent_model_apis(agent_model_apis):
    result = agent_model_apis("What is the time and weather in New York?")
    text = result.message["content"][0]["text"].lower()

    assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ or "BASETEN_DEPLOYMENT_ID" not in os.environ,
    reason="BASETEN_API_KEY or BASETEN_DEPLOYMENT_ID environment variable missing",
)
def test_agent_dedicated_deployment(agent_dedicated):
    result = agent_dedicated("What is the time and weather in New York?")
    text = result.message["content"][0]["text"].lower()

    assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_structured_output_model_apis(model_model_apis):
    class Weather(BaseModel):
        """Extracts the time and weather from the user's message with the exact strings."""

        time: str
        weather: str

    agent = Agent(model=model_model_apis)

    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
    assert isinstance(result, Weather)
    assert result.time == "12:00"
    assert result.weather == "sunny"


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ or "BASETEN_DEPLOYMENT_ID" not in os.environ,
    reason="BASETEN_API_KEY or BASETEN_DEPLOYMENT_ID environment variable missing",
)
def test_structured_output_dedicated_deployment(model_dedicated_deployment):
    class Weather(BaseModel):
        """Extracts the time and weather from the user's message with the exact strings."""

        time: str
        weather: str

    agent = Agent(model=model_dedicated_deployment)

    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
    assert isinstance(result, Weather)
    assert result.time == "12:00"
    assert result.weather == "sunny"


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_llama_model_model_apis():
    """Test with Llama 4 Maverick model on Model APIs."""
    model = BasetenModel(
        model_id="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        client_args={
            "api_key": os.getenv("BASETEN_API_KEY"),
        },
    )

    agent = Agent(model=model)
    result = agent("Hello, how are you?")

    assert result.message["content"][0]["text"] is not None
    assert len(result.message["content"][0]["text"]) > 0


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_deepseek_r1_model_apis():
    """Test with DeepSeek R1 model on Model APIs."""
    model = BasetenModel(
        model_id="deepseek-ai/DeepSeek-R1-0528",
        client_args={
            "api_key": os.getenv("BASETEN_API_KEY"),
        },
    )

    agent = Agent(model=model)
    result = agent("What is 2 + 2?")

    assert result.message["content"][0]["text"] is not None
    assert len(result.message["content"][0]["text"]) > 0


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_llama_scout_model_apis():
    """Test with Llama 4 Scout model on Model APIs."""
    model = BasetenModel(
        model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        client_args={
            "api_key": os.getenv("BASETEN_API_KEY"),
        },
    )

    agent = Agent(model=model)
    result = agent("Explain quantum computing in simple terms.")

    assert result.message["content"][0]["text"] is not None
    assert len(result.message["content"][0]["text"]) > 0
Reviewer comment: This import path is out of date.