feat: add structured output support using Pydantic models #60


Merged
merged 24 commits into main from feature/structured-output on Jun 19, 2025
Changes from 1 commit

Commits (24)

e183907
feat: add structured output support using Pydantic models
theagenticguy May 20, 2025
03942ae
fix: import cleanups and unused vars
theagenticguy May 20, 2025
19a580d
Merge branch 'main' into feature/structured-output
theagenticguy Jun 5, 2025
510def6
feat: wip adding `structured_output` methods
theagenticguy Jun 5, 2025
c3ffbce
feat: wip added structured output to bedrock and anthropic
theagenticguy Jun 5, 2025
0f03889
Merge branch 'strands-agents:main' into feature/structured-output
theagenticguy Jun 5, 2025
dce0a81
feat: litellm structured output and some integ tests
theagenticguy Jun 7, 2025
5262dfc
feat: all structured outputs working, tbd llama api
theagenticguy Jun 8, 2025
2a1f5ed
Merge branch 'strands-agents:main' into feature/structured-output
theagenticguy Jun 8, 2025
23df2c6
feat: updated docstring
theagenticguy Jun 8, 2025
cc78b6f
fix: otel ci dep issue
theagenticguy Jun 8, 2025
e8ef600
fix: remove unnecessary changes and comments
theagenticguy Jun 9, 2025
6eeeaa8
feat: basic test WIP
theagenticguy Jun 9, 2025
51f1f1d
feat: better test coverage
theagenticguy Jun 9, 2025
d5bef96
fix: remove unused fixture
theagenticguy Jun 9, 2025
c66fa32
fix: resolve some comments
theagenticguy Jun 13, 2025
422bc25
fix: inline basemodel classes
theagenticguy Jun 13, 2025
eabf075
feat: update litellm, add checks
theagenticguy Jun 17, 2025
7194d6c
Merge branch 'main' into feature/structured-output
theagenticguy Jun 17, 2025
885d3ac
fix: autoformatting issue
theagenticguy Jun 17, 2025
7308491
feat: resolves comments
theagenticguy Jun 17, 2025
a88c93b
Merge branch 'main' into feature/structured-output
theagenticguy Jun 17, 2025
0216bcc
fix: ollama skip tests, pyproject whitespace diffs
theagenticguy Jun 18, 2025
49ccfb5
Merge branch 'strands-agents:main' into feature/structured-output
theagenticguy Jun 18, 2025

feat: wip adding structured_output methods
theagenticguy committed Jun 5, 2025
commit 510def66a133727a84ef05ab33c0981b95689bb0
139 changes: 71 additions & 68 deletions src/strands/agent/agent.py
@@ -16,11 +16,10 @@
 import random
 from concurrent.futures import ThreadPoolExecutor
 from threading import Thread
-from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, Union
+from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union
 from uuid import uuid4

 from opentelemetry import trace
-from pydantic import BaseModel

 from ..event_loop.event_loop import event_loop_cycle
 from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler
@@ -30,12 +29,11 @@
 from ..telemetry.tracer import get_tracer
 from ..tools.registry import ToolRegistry
 from ..tools.thread_pool_executor import ThreadPoolExecutorWrapper
-from ..tools.tools import PythonAgentTool
 from ..tools.watcher import ToolWatcher
 from ..types.content import ContentBlock, Message, Messages
 from ..types.exceptions import ContextWindowOverflowException
 from ..types.models import Model
-from ..types.tools import ToolConfig, ToolResult, ToolUse
+from ..types.tools import ToolConfig
 from ..types.traces import AttributeValue
 from .agent_result import AgentResult
 from .conversation_manager import (
@@ -368,70 +366,75 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
             # Re-raise the exception to preserve original behavior
             raise

-    def with_output(self, prompt: str, output_model: Type[BaseModel]) -> BaseModel:
-        """Set the output model for the agent.
-
-        Args:
-            prompt: The prompt to use for the agent.
-            output_model: The output model to use for the agent.
-
-        Returns: the loaded basemodel
-        """
-        from ..tools.structured_output import convert_pydantic_to_bedrock_tool
-
-        # Convert the pydantic basemodel to a tool spec
-        tool_spec = convert_pydantic_to_bedrock_tool(output_model)
-
-        # Create a dynamic tool name to avoid collisions
-        tool_name = f"generate_{output_model.__name__}"
-        tool_spec["toolSpec"]["name"] = tool_name
-
-        # Register the tool with the tool registry
-        # We need a special type of tool that just passes through the input
-
-        # Create a passthrough callback that just returns the input
-        # with the signature expected by PythonAgentTool
-
-        def output_callback(
-            tool_use: ToolUse,
-            model: Any = None,  # noqa: ANN401
-            messages: Optional[dict[str, Any]] = None,  # noqa: ANN401
-            **kwargs: Any,
-        ) -> ToolResult:
-            # Return the ToolResult explicitly typed
-            result: ToolResult = {
-                "toolUseId": tool_use["toolUseId"],
-                "status": "success",
-                "content": [{"text": "Output generated successfully"}],
-            }
-            return result
-
-        tool = PythonAgentTool(tool_name=tool_name, tool_spec=tool_spec["toolSpec"], callback=output_callback)
-        self.tool_registry.register_tool(tool)
-
-        # Call the model with the tool and get the response
-        # This will run the model and invoke the tool
-        self(prompt)
-
-        # Extract the tool input from the message
-        # Find the first toolUse in the conversation history
-        tool_input = None
-        for message in self.messages:
-            if message.get("role") == "assistant":
-                for content in message.get("content", []):
-                    if isinstance(content, dict) and "toolUse" in content:
-                        tool_use = content["toolUse"]
-                        if tool_use.get("name") == tool_name:
-                            tool_input = tool_use.get("input", {})
-                            break
-            if tool_input:
-                break
-
-        # Create the output model from the tool input and return it
-        if not tool_input:
-            raise ValueError(f"Model did not generate a valid {output_model.__name__}")
-
-        return output_model(**tool_input)
+    # TODO: implement
+    # def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str]) -> BaseModel:
+    #     """Get structured output from the Agent's current context.
+
+    #     Args:
+    #         output_model(Type[BaseModel]): The output model the agent will use when responding.
+    #         prompt(Optional[str]): The prompt to use for the agent.
+
+    #     Returns:
+    #         The loaded basemodel.
+
+    #     Raises:
+    #         ValidationException: The response format from the large language model does not match the output_model
+
+    #     from ..tools.structured_output import convert_pydantic_to_bedrock_tool
+
+    #     # Convert the pydantic basemodel to a tool spec
+    #     tool_spec = convert_pydantic_to_bedrock_tool(output_model)
+
+    #     # Create a dynamic tool name to avoid collisions
+    #     tool_name = f"generate_{output_model.__name__}"
+    #     tool_spec["toolSpec"]["name"] = tool_name
+
+    #     # Register the tool with the tool registry
+    #     # We need a special type of tool that just passes through the input
+
+    #     # Create a passthrough callback that just returns the input
+    #     # with the signature expected by PythonAgentTool
+
+    #     def output_callback(
+    #         tool_use: ToolUse,
+    #         model: Any = None,  # noqa: ANN401
+    #         messages: Optional[dict[str, Any]] = None,  # noqa: ANN401
+    #         **kwargs: Any,
+    #     ) -> ToolResult:
+    #         # Return the ToolResult explicitly typed
+    #         result: ToolResult = {
+    #             "toolUseId": tool_use["toolUseId"],
+    #             "status": "success",
+    #             "content": [{"text": "Output generated successfully"}],
+    #         }
+    #         return result
+
+    #     tool = PythonAgentTool(tool_name=tool_name, tool_spec=tool_spec["toolSpec"], callback=output_callback)
+    #     self.tool_registry.register_tool(tool)
+
+    #     # Call the model with the tool and get the response
+    #     # This will run the model and invoke the tool
+    #     self(prompt)
+
+    #     # Extract the tool input from the message
+    #     # Find the first toolUse in the conversation history
+    #     tool_input = None
+    #     for message in self.messages:
+    #         if message.get("role") == "assistant":
+    #             for content in message.get("content", []):
+    #                 if isinstance(content, dict) and "toolUse" in content:
+    #                     tool_use = content["toolUse"]
+    #                     if tool_use.get("name") == tool_name:
+    #                         tool_input = tool_use.get("input", {})
+    #                         break
+    #         if tool_input:
+    #             break
+
+    #     # Create the output model from the tool input and return it
+    #     if not tool_input:
+    #         raise ValueError(f"Model did not generate a valid {output_model.__name__}")
+
+    #     return output_model(**tool_input)

     async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.
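
The commented-out `structured_output` above leans on a `convert_pydantic_to_bedrock_tool` helper that is referenced but not included in this commit. A minimal sketch of what that conversion amounts to, assuming Pydantic v2 and the Bedrock Converse `toolSpec` shape (hypothetical, not the PR's actual helper):

```python
from typing import Any, Dict, Type

from pydantic import BaseModel


def convert_pydantic_to_bedrock_tool(output_model: Type[BaseModel]) -> Dict[str, Any]:
    """Hypothetical sketch: wrap a Pydantic model's JSON schema in a Bedrock Converse toolSpec."""
    return {
        "toolSpec": {
            "name": output_model.__name__,
            "description": output_model.__doc__ or f"Generate a {output_model.__name__}",
            # model_json_schema() is the Pydantic v2 API for a model's JSON schema
            "inputSchema": {"json": output_model.model_json_schema()},
        }
    }
```

The schema guarantee comes from the final `output_model(**tool_input)` call: Pydantic validates whatever the model put in the forced tool call, not the LLM itself.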
13 changes: 12 additions & 1 deletion src/strands/models/anthropic.py
@@ -7,9 +7,10 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Iterable, Optional, TypedDict, cast
+from typing import Any, Iterable, Optional, Type, TypedDict, cast

 import anthropic
+from pydantic import BaseModel
 from typing_extensions import Required, Unpack, override

 from ..types.content import ContentBlock, Messages
@@ -369,3 +370,13 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
                 raise ContextWindowOverflowException(str(error)) from error

         raise error
+
+    @override
+    def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Optional[str]): The prompt to use for the agent. Defaults to None.
+        """
+        return output_model()
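
The `structured_output` stub above is a WIP placeholder that just instantiates `output_model()`. One plausible direction for the real implementation (an assumption, not part of this commit) is Anthropic's forced tool use, with the Pydantic schema as the tool's `input_schema`:

```python
import anthropic
from pydantic import BaseModel


class Person(BaseModel):
    name: str
    age: int


client = anthropic.Anthropic()  # assumes ANTHROPIC_API_KEY is set
response = client.messages.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=512,
    tools=[
        {
            "name": "Person",
            "description": "Extract a person record.",
            "input_schema": Person.model_json_schema(),
        }
    ],
    tool_choice={"type": "tool", "name": "Person"},  # force this specific tool
    messages=[{"role": "user", "content": "John is 30 years old."}],
)
tool_use = next(block for block in response.content if block.type == "tool_use")
person = Person(**tool_use.input)  # pydantic.ValidationError if the schema was not honored
```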
13 changes: 12 additions & 1 deletion src/strands/models/bedrock.py
@@ -6,11 +6,12 @@
 import json
 import logging
 import os
-from typing import Any, Iterable, List, Literal, Optional, cast
+from typing import Any, Iterable, List, Literal, Optional, Type, cast

 import boto3
 from botocore.config import Config as BotocoreConfig
 from botocore.exceptions import ClientError
+from pydantic import BaseModel
 from typing_extensions import TypedDict, Unpack, override

 from ..types.content import Messages
@@ -477,3 +478,13 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool:
                 return self._find_detected_and_blocked_policy(item)
         # Otherwise return False
         return False
+
+    @override
+    def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Optional[str]): The prompt to use for the agent. Defaults to None.
+        """
+        return output_model()
13 changes: 12 additions & 1 deletion src/strands/models/llamaapi.py
@@ -8,10 +8,11 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Iterable, Optional, cast
+from typing import Any, Iterable, Optional, Type, cast

 import llama_api_client
 from llama_api_client import LlamaAPIClient
+from pydantic import BaseModel
 from typing_extensions import TypedDict, Unpack, override

 from ..types.content import ContentBlock, Messages
@@ -384,3 +385,13 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
         # we may have a metrics event here
         if metrics_event:
             yield {"chunk_type": "metadata", "data": metrics_event}
+
+    @override
+    def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Optional[str]): The prompt to use for the agent. Defaults to None.
+        """
+        return output_model()
13 changes: 12 additions & 1 deletion src/strands/models/ollama.py
@@ -5,9 +5,10 @@

 import json
 import logging
-from typing import Any, Iterable, Optional, cast
+from typing import Any, Iterable, Optional, Type, cast

 from ollama import Client as OllamaClient
+from pydantic import BaseModel
 from typing_extensions import TypedDict, Unpack, override

 from ..types.content import ContentBlock, Messages
@@ -310,3 +311,13 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
         yield {"chunk_type": "content_stop", "data_type": "text"}
         yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else event.done_reason}
         yield {"chunk_type": "metadata", "data": event}
+
+    @override
+    def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Optional[str]): The prompt to use for the agent. Defaults to None.
+        """
+        return output_model()
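
Same placeholder pattern here. Worth noting: Ollama has native structured outputs, so a real implementation could plausibly pass the JSON schema via the chat API's `format` parameter instead of a forced tool call. A sketch under that assumption (not this commit's code; response field access follows recent ollama-python versions):

```python
from ollama import Client
from pydantic import BaseModel


class Person(BaseModel):
    name: str
    age: int


client = Client()  # assumes a local Ollama server is running
response = client.chat(
    model="llama3.1",  # placeholder name; any pulled model
    messages=[{"role": "user", "content": "John is 30 years old."}],
    format=Person.model_json_schema(),  # constrain the reply to this schema
)
person = Person.model_validate_json(response.message.content)
```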
20 changes: 19 additions & 1 deletion src/strands/types/models/model.py
@@ -2,7 +2,9 @@

 import abc
 import logging
-from typing import Any, Iterable, Optional
+from typing import Any, Iterable, Optional, Type
+
+from pydantic import BaseModel

 from ..content import Messages
 from ..streaming import StreamEvent
@@ -38,6 +40,22 @@ def get_config(self) -> Any:
         """
         pass

+    @abc.abstractmethod
+    # pragma: no cover
+    def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Optional[str]): The prompt to use for the agent. Defaults to None.
+
+        Returns:
+            The structured output as a serialized instance of the output model.
+
+        Raises:
+            ValidationException: The response format from the model does not match the output_model
+        """
+
     @abc.abstractmethod
     # pragma: no cover
     def format_request(
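
This abstract method is the contract the provider stubs above are starting to fill in. From the caller's side, the intended shape is roughly as follows (`Person` and `extract_person` are hypothetical illustrations, not PR code):

```python
from pydantic import BaseModel

from strands.types.models import Model


class Person(BaseModel):
    name: str
    age: int


def extract_person(model: Model) -> Person:
    # `model` is any concrete implementation (Bedrock, Anthropic, Ollama, ...)
    person = model.structured_output(Person, prompt="Extract: John is 30 years old.")
    assert isinstance(person, Person)  # a validated instance, not raw JSON
    return person
```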
13 changes: 12 additions & 1 deletion src/strands/types/models/openai.py
@@ -11,8 +11,9 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Optional, cast
+from typing import Any, Optional, Type, cast

+from pydantic import BaseModel
 from typing_extensions import override

 from ..content import ContentBlock, Messages
@@ -262,3 +263,13 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:

             case _:
                 raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
+
+    @override
+    def structured_output(self, output_model: Type[BaseModel], prompt: Optional[str] = None) -> BaseModel:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Optional[str]): The prompt to use for the agent. Defaults to None.
+        """
+        return output_model()
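
Another placeholder. For OpenAI-compatible providers, one candidate replacement (an assumption, not part of this commit) is the openai SDK's `parse` helper, which accepts a Pydantic class directly as `response_format`:

```python
from openai import OpenAI
from pydantic import BaseModel


class Person(BaseModel):
    name: str
    age: int


client = OpenAI()  # assumes OPENAI_API_KEY is set
completion = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",  # structured outputs need a supporting model
    messages=[{"role": "user", "content": "John is 30 years old."}],
    response_format=Person,
)
person = completion.choices[0].message.parsed  # already a Person instance
```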
File renamed without changes.
17 changes: 17 additions & 0 deletions tests/strands/types/models/test_model.py
@@ -1,15 +1,26 @@
+from typing import Type
+
 import pytest
+from pydantic import BaseModel

 from strands.types.models import Model as SAModel


+class Person(BaseModel):
+    name: str
+    age: int
+
+
 class TestModel(SAModel):
     def update_config(self, **model_config):
         return model_config

     def get_config(self):
         return

+    def structured_output(self, output_model: Type[BaseModel]) -> BaseModel:
+        return output_model(name="test", age=20)
+
     def format_request(self, messages, tool_specs, system_prompt):
         return {
             "messages": messages,
@@ -79,3 +90,9 @@ def test_converse(model, messages, tool_specs, system_prompt):
         },
     ]
     assert tru_events == exp_events
+
+
+def test_structured_output(model):
+    response = model.structured_output(Person)
+
+    assert response == Person(name="test", age=20)
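
The test pins down the contract: `structured_output` returns a validated instance of the requested model. Once a backend feeds real LLM output into `output_model(**tool_input)`, the `ValidationException` the docstrings mention would plausibly surface from Pydantic's own validation, e.g.:

```python
from pydantic import BaseModel, ValidationError


class Person(BaseModel):
    name: str
    age: int


try:
    Person(**{"name": "test"})  # simulates a model response missing "age"
except ValidationError as err:
    print(err.error_count(), "validation error(s)")  # -> 1 validation error(s)
```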