Add Baseten integration #389


Open · wants to merge 1 commit into main

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -49,6 +49,9 @@ packages = ["src/strands"]
anthropic = [
    "anthropic>=0.21.0,<1.0.0",
]
baseten = [
    "openai>=1.68.0,<2.0.0",
]
dev = [
    "commitizen>=4.4.0,<5.0.0",
    "hatch>=1.0.0,<2.0.0",
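
With this extra in place, the integration's only runtime dependency is the OpenAI SDK. Assuming the published distribution is named `strands-agents` (the name is not shown in this diff), the extra would be installed with `pip install "strands-agents[baseten]"`.
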
185 changes: 185 additions & 0 deletions src/strands/models/baseten.py
@@ -0,0 +1,185 @@
"""Baseten model provider.
- Docs: https://docs.baseten.co/
"""

import logging
from typing import Any, Generator, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast

import openai
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
from pydantic import BaseModel
from typing_extensions import Unpack, override

from ..types.content import Messages
from ..types.models import OpenAIModel

[Review comment, Contributor] This import path is out of date.
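(The corrected path is not given in this thread; presumably the OpenAI-compatible base class now lives elsewhere in the package, e.g. `from .openai import OpenAIModel`, but that exact path is an assumption.)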

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class Client(Protocol):
    """Protocol defining the OpenAI-compatible interface for the underlying provider client."""

    @property
    # pragma: no cover
    def chat(self) -> Any:
        """Chat completions interface."""
        ...


class BasetenModel(OpenAIModel):
    """Baseten model provider implementation."""

    client: Client

    class BasetenConfig(TypedDict, total=False):
        """Configuration options for Baseten models.

        Attributes:
            model_id: Model ID for the Baseten model.
                For Model APIs, use model slugs like "deepseek-ai/DeepSeek-R1-0528" or
                "meta-llama/Llama-4-Maverick-17B-128E-Instruct".
                For dedicated deployments, use the deployment ID.
            base_url: Base URL for the Baseten API.
                For Model APIs: https://inference.baseten.co/v1
                For dedicated deployments: https://model-xxxxxxx.api.baseten.co/environments/production/sync/v1
            params: Model parameters (e.g., max_tokens).
                For a complete list of supported parameters, see
                https://platform.openai.com/docs/api-reference/chat/create.
        """

        model_id: str
        base_url: Optional[str]
        params: Optional[dict[str, Any]]

    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[BasetenConfig]) -> None:
        """Initialize provider instance.

        Args:
            client_args: Arguments for the Baseten client.
                For a complete list of supported arguments, see https://pypi.org/project/openai/.
            **model_config: Configuration options for the Baseten model.
        """
        self.config = dict(model_config)

        logger.debug("config=<%s> | initializing", self.config)

        client_args = client_args or {}

        # Set the default Model APIs base URL when none is provided; a base_url in
        # the model config takes precedence over the default.
        if "base_url" not in client_args and "base_url" not in self.config:
            client_args["base_url"] = "https://inference.baseten.co/v1"
        elif "base_url" in self.config:
            client_args["base_url"] = self.config["base_url"]

        self.client = openai.OpenAI(**client_args)

[Review comment, @JackYPCOnline (Contributor), Jul 13, 2025] We've migrated to AsyncOpenAI in our implementation. Please verify this change is properly reflected throughout the codebase in your PR. Also, ensure you've pulled the most recent code before proceeding with your review.
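
For context on that request, here is a minimal sketch of what the async migration might look like for this provider. The class name `AsyncBasetenModelSketch` is hypothetical, the sketch assumes the formatted request includes `stream=True` (as the sync path below relies on), and tool-call and metadata handling are omitted:

```python
# Hypothetical sketch only: openai.AsyncOpenAI in place of openai.OpenAI, and an
# async generator in place of the synchronous stream() implemented below.
from typing import Any, AsyncGenerator, Optional

import openai


class AsyncBasetenModelSketch:
    def __init__(self, client_args: Optional[dict[str, Any]] = None) -> None:
        client_args = dict(client_args or {})
        # Same default as the sync provider: Baseten Model APIs endpoint.
        client_args.setdefault("base_url", "https://inference.baseten.co/v1")
        self.client = openai.AsyncOpenAI(**client_args)

    async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
        # On the async client, create(...) must be awaited and the resulting
        # stream consumed with `async for` (request is assumed to set stream=True).
        response = await self.client.chat.completions.create(**request)

        yield {"chunk_type": "message_start"}
        yield {"chunk_type": "content_start", "data_type": "text"}

        async for event in response:
            if getattr(event, "choices", None) and event.choices[0].delta.content:
                yield {
                    "chunk_type": "content_delta",
                    "data_type": "text",
                    "data": event.choices[0].delta.content,
                }

        yield {"chunk_type": "content_stop", "data_type": "text"}
        yield {"chunk_type": "message_stop"}
```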

    @override
    def update_config(self, **model_config: Unpack[BasetenConfig]) -> None:  # type: ignore[override]
        """Update the Baseten model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> BasetenConfig:
        """Get the Baseten model configuration.

        Returns:
            The Baseten model configuration.
        """
        return cast(BasetenModel.BasetenConfig, self.config)

[Review comment, @JackYPCOnline (Contributor), Jul 13, 2025, on stream()] Also here.

    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
"""Send the request to the Baseten model and get the streaming response.
Args:
request: The formatted request to send to the Baseten model.
Returns:
An iterable of response events from the Baseten model.
"""
response = self.client.chat.completions.create(**request)

[Review comment, Contributor] async happens also here ^^
yield {"chunk_type": "message_start"}
yield {"chunk_type": "content_start", "data_type": "text"}

tool_calls: dict[int, list[Any]] = {}

for event in response:
# Defensive: skip events with empty or missing choices
if not getattr(event, "choices", None):
continue
choice = event.choices[0]

if choice.delta.content:
yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}

if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
yield {
"chunk_type": "content_delta",
"data_type": "reasoning_content",
"data": choice.delta.reasoning_content,
}

for tool_call in choice.delta.tool_calls or []:
tool_calls.setdefault(tool_call.index, []).append(tool_call)

if choice.finish_reason:
break

yield {"chunk_type": "content_stop", "data_type": "text"}

for tool_deltas in tool_calls.values():
yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}

for tool_delta in tool_deltas:
yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}

yield {"chunk_type": "content_stop", "data_type": "tool"}

yield {"chunk_type": "message_stop", "data": choice.finish_reason}

# Skip remaining events as we don't have use for anything except the final usage payload
for event in response:
_ = event

yield {"chunk_type": "metadata", "data": event.usage}

[Review comment, Contributor] You might want to update async here as well.

    @override
    def structured_output(
        self, output_model: Type[T], prompt: Messages
    ) -> Generator[dict[str, Union[T, Any]], None, None]:
"""Get structured output from the model.
Args:
output_model: The output model to use for the agent.
prompt: The prompt messages to use for the agent.
Yields:
Model events with the last being the structured output.
"""
response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore
model=self.get_config()["model_id"],
messages=super().format_request(prompt)["messages"],
response_format=output_model,
)

parsed: T | None = None
# Find the first choice with tool_calls
if len(response.choices) > 1:
raise ValueError("Multiple choices found in the Baseten response.")

for choice in response.choices:
if isinstance(choice.message.parsed, output_model):
parsed = choice.message.parsed
break

if parsed:
yield {"output": parsed}
else:
raise ValueError("No valid tool use or tool use input was found in the Baseten response.")
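
For reviewers trying the provider locally, a short usage sketch grounded in the `BasetenConfig` docstring above and the integration tests below; `model-xxxxxxx` is the docstring's placeholder, not a real deployment ID, and `BASETEN_API_KEY` must be set in the environment:

```python
import os

from strands import Agent
from strands.models.baseten import BasetenModel

# Model APIs: __init__ applies https://inference.baseten.co/v1 as the default
# base_url, so a model slug and API key are all that is required.
model = BasetenModel(
    model_id="deepseek-ai/DeepSeek-V3-0324",
    client_args={"api_key": os.environ["BASETEN_API_KEY"]},
)

# Dedicated deployment: pass the deployment-specific base_url instead.
dedicated = BasetenModel(
    base_url="https://model-xxxxxxx.api.baseten.co/environments/production/sync/v1",
    client_args={"api_key": os.environ["BASETEN_API_KEY"]},
)

agent = Agent(model=model)
result = agent("What is 2 + 2?")
print(result.message["content"][0]["text"])
```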
175 changes: 175 additions & 0 deletions tests-integ/test_model_baseten.py
@@ -0,0 +1,175 @@
import os

import pytest
from pydantic import BaseModel

import strands
from strands import Agent
from strands.models.baseten import BasetenModel


@pytest.fixture
def model_model_apis():
    """Test with Model APIs using the DeepSeek V3 model."""
    return BasetenModel(
        model_id="deepseek-ai/DeepSeek-V3-0324",
        client_args={
            "api_key": os.getenv("BASETEN_API_KEY"),
        },
    )


@pytest.fixture
def model_dedicated_deployment():
    """Test with a dedicated deployment -- set BASETEN_DEPLOYMENT_ID to your deployment ID when testing."""
    deployment_id = os.getenv("BASETEN_DEPLOYMENT_ID", "232k7g23")
    base_url = f"https://model-{deployment_id}.api.baseten.co/environments/production/sync/v1"

    return BasetenModel(
        base_url=base_url,
        client_args={
            "api_key": os.getenv("BASETEN_API_KEY"),
        },
    )


@pytest.fixture
def tools():
    @strands.tool
    def tool_time() -> str:
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        return "sunny"

    return [tool_time, tool_weather]


@pytest.fixture
def agent_model_apis(model_model_apis, tools):
    return Agent(model=model_model_apis, tools=tools)


@pytest.fixture
def agent_dedicated(model_dedicated_deployment, tools):
    return Agent(model=model_dedicated_deployment, tools=tools)


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_agent_model_apis(agent_model_apis):
    result = agent_model_apis("What is the time and weather in New York?")
    text = result.message["content"][0]["text"].lower()

    assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ or "BASETEN_DEPLOYMENT_ID" not in os.environ,
    reason="BASETEN_API_KEY or BASETEN_DEPLOYMENT_ID environment variable missing",
)
def test_agent_dedicated_deployment(agent_dedicated):
    result = agent_dedicated("What is the time and weather in New York?")
    text = result.message["content"][0]["text"].lower()

    assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_structured_output_model_apis(model_model_apis):
    class Weather(BaseModel):
        """Extracts the time and weather from the user's message with the exact strings."""

        time: str
        weather: str

    agent = Agent(model=model_model_apis)

    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
    assert isinstance(result, Weather)
    assert result.time == "12:00"
    assert result.weather == "sunny"


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ or "BASETEN_DEPLOYMENT_ID" not in os.environ,
    reason="BASETEN_API_KEY or BASETEN_DEPLOYMENT_ID environment variable missing",
)
def test_structured_output_dedicated_deployment(model_dedicated_deployment):
    class Weather(BaseModel):
        """Extracts the time and weather from the user's message with the exact strings."""

        time: str
        weather: str

    agent = Agent(model=model_dedicated_deployment)

    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
    assert isinstance(result, Weather)
    assert result.time == "12:00"
    assert result.weather == "sunny"


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_llama_model_model_apis():
    """Test with the Llama 4 Maverick model on Model APIs."""
    model = BasetenModel(
        model_id="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        client_args={
            "api_key": os.getenv("BASETEN_API_KEY"),
        },
    )

    agent = Agent(model=model)
    result = agent("Hello, how are you?")

    assert result.message["content"][0]["text"] is not None
    assert len(result.message["content"][0]["text"]) > 0


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_deepseek_r1_model_apis():
    """Test with the DeepSeek R1 model on Model APIs."""
    model = BasetenModel(
        model_id="deepseek-ai/DeepSeek-R1-0528",
        client_args={
            "api_key": os.getenv("BASETEN_API_KEY"),
        },
    )

    agent = Agent(model=model)
    result = agent("What is 2 + 2?")

    assert result.message["content"][0]["text"] is not None
    assert len(result.message["content"][0]["text"]) > 0


@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_llama_scout_model_apis():
    """Test with the Llama 4 Scout model on Model APIs."""
    model = BasetenModel(
        model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        client_args={
            "api_key": os.getenv("BASETEN_API_KEY"),
        },
    )

    agent = Agent(model=model)
    result = agent("Explain quantum computing in simple terms.")

    assert result.message["content"][0]["text"] is not None
    assert len(result.message["content"][0]["text"]) > 0
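
A note on running these: every test above is skipped unless `BASETEN_API_KEY` is set (plus `BASETEN_DEPLOYMENT_ID` for the dedicated-deployment cases), so locally they would be run with something like `BASETEN_API_KEY=... pytest tests-integ/test_model_baseten.py`.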