feat: add structured output support using Pydantic models #60

Open: theagenticguy wants to merge 24 commits into main from feature/structured-output
Commits (24)
e183907  feat: add structured output support using Pydantic models (theagenticguy, May 20, 2025)
03942ae  fix: import cleanups and unused vars (theagenticguy, May 20, 2025)
19a580d  Merge branch 'main' into feature/structured-output (theagenticguy, Jun 5, 2025)
510def6  feat: wip adding `structured_output` methods (theagenticguy, Jun 5, 2025)
c3ffbce  feat: wip added structured output to bedrock and anthropic (theagenticguy, Jun 5, 2025)
0f03889  Merge branch 'strands-agents:main' into feature/structured-output (theagenticguy, Jun 5, 2025)
dce0a81  feat: litellm structured output and some integ tests (theagenticguy, Jun 7, 2025)
5262dfc  feat: all structured outputs working, tbd llama api (theagenticguy, Jun 8, 2025)
2a1f5ed  Merge branch 'strands-agents:main' into feature/structured-output (theagenticguy, Jun 8, 2025)
23df2c6  feat: updated docstring (theagenticguy, Jun 8, 2025)
cc78b6f  fix: otel ci dep issue (theagenticguy, Jun 8, 2025)
e8ef600  fix: remove unnecessary changes and comments (theagenticguy, Jun 9, 2025)
6eeeaa8  feat: basic test WIP (theagenticguy, Jun 9, 2025)
51f1f1d  feat: better test coverage (theagenticguy, Jun 9, 2025)
d5bef96  fix: remove unused fixture (theagenticguy, Jun 9, 2025)
c66fa32  fix: resolve some comments (theagenticguy, Jun 13, 2025)
422bc25  fix: inline basemodel classes (theagenticguy, Jun 13, 2025)
eabf075  feat: update litellm, add checks (theagenticguy, Jun 17, 2025)
7194d6c  Merge branch 'main' into feature/structured-output (theagenticguy, Jun 17, 2025)
885d3ac  fix: autoformatting issue (theagenticguy, Jun 17, 2025)
7308491  feat: resolves comments (theagenticguy, Jun 17, 2025)
a88c93b  Merge branch 'main' into feature/structured-output (theagenticguy, Jun 17, 2025)
0216bcc  fix: ollama skip tests, pyproject whitespace diffs (theagenticguy, Jun 18, 2025)
49ccfb5  Merge branch 'strands-agents:main' into feature/structured-output (theagenticguy, Jun 18, 2025)
Files changed
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -65,7 +65,7 @@ docs = [
"sphinx-autodoc-typehints>=1.12.0,<2.0.0",
]
litellm = [
"litellm>=1.69.0,<2.0.0",
"litellm>=1.72.6,<2.0.0",
]
llamaapi = [
"llama-api-client>=0.1.0,<1.0.0",
@@ -264,4 +264,4 @@ style = [
["instruction", ""],
["text", ""],
["disabled", "fg:#858585 italic"]
]
]
32 changes: 31 additions & 1 deletion src/strands/agent/agent.py
@@ -16,10 +16,11 @@
import random
from concurrent.futures import ThreadPoolExecutor
from threading import Thread
from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union
from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
from uuid import uuid4

from opentelemetry import trace
from pydantic import BaseModel

from ..event_loop.event_loop import event_loop_cycle
from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler
@@ -43,6 +44,9 @@

logger = logging.getLogger(__name__)

# TypeVar for generic structured output
T = TypeVar("T", bound=BaseModel)


# Sentinel class and object to distinguish between explicit None and default parameter value
class _DefaultCallbackHandlerSentinel:
@@ -387,6 +391,32 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
            # Re-raise the exception to preserve original behavior
            raise

    def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
        """Get structured output from the agent.

        If a prompt is provided, it is appended to the conversation history and the agent responds to it.
        Otherwise, the agent responds using only the existing conversation history.

        For smaller models, you may want to use the optional prompt to explicitly instruct the model to
        output the structured data.

        Args:
            output_model (Type[BaseModel]): The output model (a JSON schema written as a Pydantic BaseModel)
                that the agent will use when responding.
            prompt (Optional[str]): The prompt to use for the agent.

        Raises:
            ValueError: If no conversation history exists and no prompt is provided.
        """
        messages = self.messages
        if not messages and not prompt:
            raise ValueError("No conversation history or prompt provided")

        # add the prompt as the last message
        if prompt:
            messages.append({"role": "user", "content": [{"text": prompt}]})

        # get the structured output from the model
        return self.model.structured_output(output_model, messages, self.callback_handler)

    async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
        """Process a natural language prompt and yield events as an async iterator.

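
For reference, a minimal usage sketch of the new `Agent.structured_output` API. This assumes the top-level `strands.Agent` export and a default model with tool-use support; `PersonInfo` is an illustrative model, not part of this PR:

```python
from pydantic import BaseModel

from strands import Agent


class PersonInfo(BaseModel):
    """Illustrative output model; any Pydantic BaseModel works."""

    name: str
    age: int
    occupation: str


agent = Agent()

# The prompt is appended to the conversation history, and the model is driven
# to produce output matching PersonInfo's JSON schema.
result = agent.structured_output(PersonInfo, "John Smith is a 30-year-old software engineer.")
assert isinstance(result, PersonInfo)
print(result.name, result.age, result.occupation)
```
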
51 changes: 48 additions & 3 deletions src/strands/models/anthropic.py
@@ -7,11 +7,15 @@
import json
import logging
import mimetypes
from typing import Any, Iterable, Optional, TypedDict, cast
from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast

import anthropic
from pydantic import BaseModel
from typing_extensions import Required, Unpack, override

from ..event_loop.streaming import process_stream
from ..handlers.callback_handler import PrintingCallbackHandler
from ..tools import convert_pydantic_to_tool_spec
from ..types.content import ContentBlock, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.models import Model
@@ -20,6 +24,8 @@

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class AnthropicModel(Model):
"""Anthropic model provider implementation."""
@@ -356,10 +362,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
            with self.client.messages.stream(**request) as stream:
                for event in stream:
                    if event.type in AnthropicModel.EVENT_TYPES:
                        yield event.dict()
                        yield event.model_dump()

                usage = event.message.usage  # type: ignore
                yield {"type": "metadata", "usage": usage.dict()}
                yield {"type": "metadata", "usage": usage.model_dump()}

        except anthropic.RateLimitError as error:
            raise ModelThrottledException(str(error)) from error
@@ -369,3 +375,42 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
                raise ContextWindowOverflowException(str(error)) from error

            raise error

    @override
    def structured_output(
        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
    ) -> T:
        """Get structured output from the model.

        Args:
            output_model (Type[BaseModel]): The Pydantic model defining the structure of the expected output.
            prompt (Messages): The conversation messages to send to the model.
            callback_handler (Optional[Callable]): Optional callback handler for processing events. Defaults to None.

        Raises:
            ValueError: If the response does not contain a tool use matching the output model.
        """
        tool_spec = convert_pydantic_to_tool_spec(output_model)

        response = self.converse(messages=prompt, tool_specs=[tool_spec])
        # process the stream and get the tool use input
        results = process_stream(
            response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt
        )

        stop_reason, messages, _, _, _ = results

        if stop_reason != "tool_use":
            raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")

        content = messages["content"]
        output_response: dict[str, Any] | None = None
        for block in content:
            # keep the input from the toolUse block whose name matches the tool spec;
            # skip text blocks and any non-matching tool uses
            if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]:
                output_response = block["toolUse"]["input"]

        if output_response is None:
            raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")

        return output_model(**output_response)
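
The Anthropic path hinges on `convert_pydantic_to_tool_spec` turning the output model into a forced tool call. The helper's exact behavior lives in `strands.tools` and may differ; a rough sketch of the idea, assuming a Converse-style tool spec whose `inputSchema` carries the model's JSON schema:

```python
from typing import Any, Type

from pydantic import BaseModel


def pydantic_to_tool_spec_sketch(output_model: Type[BaseModel]) -> dict[str, Any]:
    """Hypothetical stand-in for convert_pydantic_to_tool_spec.

    Wraps the model's JSON schema in a Converse-style tool spec, so the only
    way for the model to "call the tool" is to emit data matching the schema.
    """
    return {
        "name": output_model.__name__,
        "description": output_model.__doc__ or f"Structured output for {output_model.__name__}",
        "inputSchema": {"json": output_model.model_json_schema()},
    }
```
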
47 changes: 46 additions & 1 deletion src/strands/models/bedrock.py
@@ -6,13 +6,17 @@
import json
import logging
import os
from typing import Any, Iterable, List, Literal, Optional, cast
from typing import Any, Callable, Iterable, List, Literal, Optional, Type, TypeVar, cast

import boto3
from botocore.config import Config as BotocoreConfig
from botocore.exceptions import ClientError
from pydantic import BaseModel
from typing_extensions import TypedDict, Unpack, override

from ..event_loop.streaming import process_stream
from ..handlers.callback_handler import PrintingCallbackHandler
from ..tools import convert_pydantic_to_tool_spec
from ..types.content import Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.models import Model
@@ -29,6 +33,8 @@
"too many total text bytes",
]

T = TypeVar("T", bound=BaseModel)


class BedrockModel(Model):
"""AWS Bedrock model provider implementation.
@@ -477,3 +483,42 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool:
                return self._find_detected_and_blocked_policy(item)
        # Otherwise return False
        return False

    @override
    def structured_output(
        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
    ) -> T:
        """Get structured output from the model.

        Args:
            output_model (Type[BaseModel]): The Pydantic model defining the structure of the expected output.
            prompt (Messages): The conversation messages to send to the model.
            callback_handler (Optional[Callable]): Optional callback handler for processing events. Defaults to None.

        Raises:
            ValueError: If the response does not contain a tool use matching the output model.
        """
        tool_spec = convert_pydantic_to_tool_spec(output_model)

        response = self.converse(messages=prompt, tool_specs=[tool_spec])
        # process the stream and get the tool use input
        results = process_stream(
            response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt
        )

        stop_reason, messages, _, _, _ = results

        if stop_reason != "tool_use":
            raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")

        content = messages["content"]
        output_response: dict[str, Any] | None = None
        for block in content:
            # keep the input from the toolUse block whose name matches the tool spec;
            # skip text blocks and any non-matching tool uses
            if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]:
                output_response = block["toolUse"]["input"]

        if output_response is None:
            raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")

        return output_model(**output_response)
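
The model-level API can also be driven directly, without an `Agent`. A sketch under the assumption that `BedrockModel` is exported from `strands.models` and the default model supports tool use; the message shape matches what `Agent.structured_output` builds:

```python
from pydantic import BaseModel

from strands.models import BedrockModel


class Invoice(BaseModel):
    """Illustrative output model."""

    vendor: str
    total_usd: float


model = BedrockModel()
messages = [{"role": "user", "content": [{"text": "Acme Corp billed us $1,250.00."}]}]

# Bypasses the Agent layer: converse() runs with the converted tool spec, and
# the resulting tool-use input is validated into the Pydantic model.
invoice = model.structured_output(Invoice, messages)
print(invoice.vendor, invoice.total_usd)
```
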
49 changes: 47 additions & 2 deletions src/strands/models/litellm.py
@@ -3,17 +3,22 @@
- Docs: https://docs.litellm.ai/
"""

import json
import logging
from typing import Any, Optional, TypedDict, cast
from typing import Any, Callable, Optional, Type, TypedDict, TypeVar, cast

import litellm
from litellm.utils import supports_response_schema
from pydantic import BaseModel
from typing_extensions import Unpack, override

from ..types.content import ContentBlock
from ..types.content import ContentBlock, Messages
from .openai import OpenAIModel

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class LiteLLMModel(OpenAIModel):
"""LiteLLM model provider implementation."""
@@ -97,3 +102,43 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
            }

        return super().format_request_message_content(content)

    @override
    def structured_output(
        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
    ) -> T:
        """Get structured output from the model.

        Args:
            output_model (Type[BaseModel]): The Pydantic model defining the structure of the expected output.
            prompt (Messages): The conversation messages to send to the model.
            callback_handler (Optional[Callable]): Optional callback handler for processing events. Defaults to None.

        Raises:
            ValueError: If the model does not support response schemas or the response cannot be parsed.
        """
        # Check schema support up front, before spending a round trip on the API call.
        if not supports_response_schema(self.get_config()["model_id"]):
            raise ValueError("Model does not support response_format")

        # `self.client.chat.completions.create()` wraps LiteLLM's completion API.
        response = self.client.chat.completions.create(
            model=self.get_config()["model_id"],
            messages=super().format_request(prompt)["messages"],
            response_format=output_model,
        )

        if len(response.choices) > 1:
            raise ValueError("Multiple choices found in the response.")

        # Find the first choice with tool_calls
        for choice in response.choices:
            if choice.finish_reason == "tool_calls":
                try:
                    # Parse the tool call content as JSON
                    tool_call_data = json.loads(choice.message.content)
                    # Instantiate the output model with the parsed data
                    return output_model(**tool_call_data)
                except (json.JSONDecodeError, TypeError, ValueError) as e:
                    raise ValueError(f"Failed to parse or load content into model: {e}") from e

        # If no tool_calls found, raise an error
        raise ValueError("No tool_calls found in response")
33 changes: 32 additions & 1 deletion src/strands/models/llamaapi.py
@@ -8,10 +8,11 @@
import json
import logging
import mimetypes
from typing import Any, Iterable, Optional, cast
from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast

import llama_api_client
from llama_api_client import LlamaAPIClient
from pydantic import BaseModel
from typing_extensions import TypedDict, Unpack, override

from ..types.content import ContentBlock, Messages
@@ -22,6 +23,8 @@

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class LlamaAPIModel(Model):
"""Llama API model provider implementation."""
@@ -384,3 +387,31 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
        # we may have a metrics event here
        if metrics_event:
            yield {"chunk_type": "metadata", "data": metrics_event}

    @override
    def structured_output(
        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
    ) -> T:
        """Get structured output from the model.

        Args:
            output_model (Type[BaseModel]): The Pydantic model defining the structure of the expected output.
            prompt (Messages): The conversation messages to send to the model.
            callback_handler (Optional[Callable]): Optional callback handler for processing events. Defaults to None.

        Raises:
            NotImplementedError: Structured output is not currently supported for LlamaAPI models.
        """
        # response_format: ResponseFormat = {
        #     "type": "json_schema",
        #     "json_schema": {
        #         "name": output_model.__name__,
        #         "schema": output_model.model_json_schema(),
        #     },
        # }
        # response = self.client.chat.completions.create(
        #     model=self.config["model_id"],
        #     messages=self.format_request(prompt)["messages"],
        #     response_format=response_format,
        # )
        raise NotImplementedError("Structured output is not currently implemented for the Llama API preview.")
27 changes: 26 additions & 1 deletion src/strands/models/ollama.py
@@ -5,9 +5,10 @@

import json
import logging
from typing import Any, Iterable, Optional, cast
from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast

from ollama import Client as OllamaClient
from pydantic import BaseModel
from typing_extensions import TypedDict, Unpack, override

from ..types.content import ContentBlock, Messages
@@ -17,6 +18,8 @@

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class OllamaModel(Model):
"""Ollama model provider implementation.
@@ -310,3 +313,25 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
yield {"chunk_type": "content_stop", "data_type": "text"}
yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else event.done_reason}
yield {"chunk_type": "metadata", "data": event}

    @override
    def structured_output(
        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
    ) -> T:
        """Get structured output from the model.

        Args:
            output_model (Type[BaseModel]): The Pydantic model defining the structure of the expected output.
            prompt (Messages): The conversation messages to send to the model.
            callback_handler (Optional[Callable]): Optional callback handler for processing events. Defaults to None.

        Raises:
            ValueError: If the response cannot be parsed into the output model.
        """
        formatted_request = self.format_request(messages=prompt)
        # Ollama's native structured output: pass the JSON schema via `format` and disable streaming.
        formatted_request["format"] = output_model.model_json_schema()
        formatted_request["stream"] = False
        response = self.client.chat(**formatted_request)

        try:
            content = response.message.content.strip()
            return output_model.model_validate_json(content)
        except Exception as e:
            raise ValueError(f"Failed to parse or load content into model: {e}") from e