Skip to content

Commit c020b6b

Browse files
author
Vamil Gandhi
committed
refactor: simplify outputSchema implementation based on PR feedback
- Use NotRequired[JSONSchema] instead of total=False for ToolSpec - Remove filter_tool_specs method from Model base class - Always filter outputSchema in streaming.py (not configurable) - Add TODO comment referencing issue #780 for future improvements - Remove example file and filtering tests as requested - Update test mocks to remove filter_tool_specs references This simplifies the implementation by unconditionally filtering outputSchema in the event loop until proper model-specific behavior handling is implemented in #780.
1 parent c471eb8 commit c020b6b

File tree

7 files changed

+16
-298
lines changed

7 files changed

+16
-298
lines changed

examples/output_schema_example.py

Lines changed: 0 additions & 112 deletions
This file was deleted.

src/strands/event_loop/streaming.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44
import logging
5-
from typing import Any, AsyncGenerator, AsyncIterable, Optional
5+
from typing import Any, AsyncGenerator, AsyncIterable, Optional, cast
66

77
from ..models.model import Model
88
from ..types._events import (
@@ -337,8 +337,16 @@ async def stream_messages(
337337
logger.debug("model=<%s> | streaming messages", model)
338338

339339
messages = remove_blank_messages_content_text(messages)
340-
# Filter outputschema spec based on model configuration until all models supports it.
341-
filtered_tool_specs = model.filter_tool_specs(tool_specs) if tool_specs else None
340+
341+
# TODO(#780): Remove outputSchema filtering once model-specific behavior handling is implemented.
342+
# For now, we filter out outputSchema from all tool specs to ensure compatibility with all model providers.
343+
# Some providers (e.g., Bedrock) will throw validation errors if they receive unknown fields.
344+
filtered_tool_specs = None
345+
if tool_specs:
346+
filtered_tool_specs = cast(
347+
list[ToolSpec], [{k: v for k, v in spec.items() if k != "outputSchema"} for spec in tool_specs]
348+
)
349+
342350
chunks = model.stream(messages, filtered_tool_specs, system_prompt)
343351

344352
async for event in process_stream(chunks):

src/strands/models/model.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import abc
44
import logging
5-
from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union, cast
5+
from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union
66

77
from pydantic import BaseModel
88

@@ -22,27 +22,6 @@ class Model(abc.ABC):
2222
standardized way to configure and process requests for different AI model providers.
2323
"""
2424

25-
def filter_tool_specs(self, tool_specs: Optional[list[ToolSpec]]) -> Optional[list[ToolSpec]]:
26-
"""Filter tool specifications based on model configuration.
27-
28-
By default, this removes the outputSchema field from tool specs unless
29-
the model configuration explicitly enables it via `supports_tool_output_schema`.
30-
31-
Args:
32-
tool_specs: List of tool specifications to filter.
33-
34-
Returns:
35-
Filtered tool specifications safe for the model provider.
36-
"""
37-
if not tool_specs:
38-
return tool_specs
39-
40-
config = self.get_config()
41-
if isinstance(config, dict) and config.get("supports_tool_output_schema", False):
42-
return tool_specs
43-
44-
return cast(list[ToolSpec], [{k: v for k, v in spec.items() if k != "outputSchema"} for spec in tool_specs])
45-
4625
@abc.abstractmethod
4726
# pragma: no cover
4827
def update_config(self, **model_config: Any) -> None:

src/strands/types/tools.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dataclasses import dataclass
1010
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union
1111

12-
from typing_extensions import TypedDict
12+
from typing_extensions import NotRequired, TypedDict
1313

1414
from .media import DocumentContent, ImageContent
1515

@@ -20,7 +20,7 @@
2020
"""Type alias for JSON Schema dictionaries."""
2121

2222

23-
class ToolSpec(TypedDict, total=False):
23+
class ToolSpec(TypedDict):
2424
"""Specification for a tool that can be used by an agent.
2525
2626
Attributes:
@@ -35,7 +35,7 @@ class ToolSpec(TypedDict, total=False):
3535
description: str
3636
inputSchema: JSONSchema
3737
name: str
38-
outputSchema: JSONSchema
38+
outputSchema: NotRequired[JSONSchema]
3939

4040

4141
class Tool(TypedDict):

tests/strands/agent/test_agent.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ async def stream(*args, **kwargs):
4949
mock = unittest.mock.Mock(spec=getattr(request, "param", None))
5050
mock.configure_mock(mock_stream=unittest.mock.MagicMock())
5151
mock.stream.side_effect = stream
52-
mock.filter_tool_specs = lambda tool_specs: tool_specs
53-
mock.get_config.return_value = {}
5452

5553
return mock
5654

tests/strands/event_loop/test_event_loop.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ def mock_time():
3333

3434
@pytest.fixture
3535
def model():
36-
mock = unittest.mock.Mock()
37-
mock.filter_tool_specs.side_effect = lambda tool_specs: tool_specs
38-
return mock
36+
return unittest.mock.Mock()
3937

4038

4139
@pytest.fixture

tests/strands/models/test_output_schema_filtering.py

Lines changed: 0 additions & 153 deletions
This file was deleted.

0 commit comments

Comments
 (0)