Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ TESTS_PATH := tests
-include .env

ifndef UV_VERSION
UV_VERSION := 0.8.17
UV_VERSION := 0.8.22
endif

.PHONY: uv_check venv sync update format lint test docs docs-server release
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "uv_build"
[project]
name = "draive"
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
version = "0.85.5"
version = "0.86.0"
readme = "README.md"
maintainers = [
{ name = "Kacper Kaliński", email = "kacper.kalinski@miquido.com" },
Expand All @@ -21,7 +21,7 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Application Frameworks",
]
license = { file = "LICENSE" }
dependencies = ["numpy~=2.2", "haiway~=0.32.0"]
dependencies = ["numpy~=2.2", "haiway~=0.34.1"]

[project.urls]
Homepage = "https://miquido.com"
Expand Down
7 changes: 4 additions & 3 deletions src/draive/mcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from draive.models import FunctionTool, ModelToolSpecification, Tool, ToolError, ToolsProvider
from draive.multimodal import ArtifactContent, MultimodalContent, TextContent
from draive.parameters import DataModel, validated_tool_specification
from draive.parameters import DataModel, ToolParametersSpecification
from draive.resources import ResourceContent, ResourceReference, ResourcesRepository

__all__ = (
Expand Down Expand Up @@ -480,11 +480,12 @@ async def remote_call(**arguments: Any) -> MultimodalContent:
return FunctionTool(
name=name,
description=mcp_tool.description,
parameters=validated_tool_specification(
parameters=cast(
ToolParametersSpecification,
{
**mcp_tool.inputSchema,
"additionalProperties": False,
}
},
),
Comment on lines +483 to 489
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Validate/normalize MCP tool inputSchema instead of blind cast

mcp_tool.inputSchema can be None or a non-object schema. Spreading it and casting will raise at runtime and/or violate ToolParametersSpecification invariants (type must be "object"). Build a safe, strict object schema and propagate optional fields.

Apply this diff in-place:

-        parameters=cast(
-            ToolParametersSpecification,
-            {
-                **mcp_tool.inputSchema,
-                "additionalProperties": False,
-            },
-        ),
+        parameters=_coerce_tool_parameters_spec(mcp_tool.inputSchema),

Add this helper (outside the shown range, e.g., below _convert_tool or near module top):

def _coerce_tool_parameters_spec(
    input_schema: Mapping[str, Any] | None,
) -> ToolParametersSpecification:
    if not isinstance(input_schema, Mapping):
        ctx.log_warning("Missing or invalid MCP tool inputSchema; defaulting to strict object")
        return {
            "type": "object",
            "properties": {},
            "additionalProperties": False,
        }

    type_value = input_schema.get("type", "object")
    if type_value != "object":
        ctx.log_warning(f"Unexpected MCP tool inputSchema.type={type_value!r}; forcing 'object'")

    spec: ToolParametersSpecification = {
        "type": "object",
        "properties": cast(Mapping[str, Any], input_schema.get("properties", {})),
        "additionalProperties": False,
    }

    required = input_schema.get("required")
    if isinstance(required, Sequence):
        spec["required"] = [str(k) for k in required]  # type: ignore[assignment]

    title = input_schema.get("title")
    if isinstance(title, str):
        spec["title"] = title  # type: ignore[assignment]

    description = input_schema.get("description")
    if isinstance(description, str):
        spec["description"] = description  # type: ignore[assignment]

    return spec
🤖 Prompt for AI Agents
In src/draive/mcp/client.py around lines 483 to 489, the code blindly casts
mcp_tool.inputSchema into ToolParametersSpecification which can be None or a
non-object schema; create a helper _coerce_tool_parameters_spec(input_schema) as
described and use it instead of the cast so you always produce a strict object
schema (type "object", properties from input if mapping, additionalProperties
False) while preserving optional fields title, description, and required (coerce
required to strings if sequence), and emit ctx.log_warning when input_schema is
missing/invalid or its type is not "object"; place the helper near the module
top or below _convert_tool and replace the current parameters=cast(...) call
with parameters=_coerce_tool_parameters_spec(mcp_tool.inputSchema).

function=remote_call,
availability=_available,
Expand Down
32 changes: 16 additions & 16 deletions src/draive/openai/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from uuid import uuid4

from haiway import MISSING, Meta, Missing, ObservabilityLevel, as_dict, ctx, unwrap_missing
from openai import NOT_GIVEN, NotGiven
from openai import Omit, omit
from openai import RateLimitError as OpenAIRateLimitError
from openai.types.responses import (
Response,
Expand Down Expand Up @@ -217,9 +217,9 @@ async def _completion( # noqa: C901, PLR0912, PLR0915
try:
response: Response = await self._client.responses.create(
model=config.model,
instructions=instructions or NOT_GIVEN,
instructions=instructions or omit,
input=input_context,
temperature=unwrap_missing(config.temperature, default=NOT_GIVEN),
temperature=unwrap_missing(config.temperature, default=omit),
tool_choice=_tool_choice(tools),
tools=_tools_as_tool_params(tools.specifications),
text=_text_output(output, verbosity=config.verbosity),
Expand All @@ -229,18 +229,18 @@ async def _completion( # noqa: C901, PLR0912, PLR0915
summary=config.reasoning_summary,
)
if isinstance(config.reasoning, str)
else NOT_GIVEN
else omit
),
parallel_tool_calls=config.parallel_tool_calls,
max_output_tokens=config.max_output_tokens,
service_tier=config.service_tier,
truncation=config.truncation,
safety_identifier=config.safety_identifier or NOT_GIVEN,
prompt_cache_key=cache_key or NOT_GIVEN,
safety_identifier=config.safety_identifier or omit,
prompt_cache_key=cache_key or omit,
include=["reasoning.encrypted_content"]
# for gpt-5 model family we need to request encrypted reasoning
if "gpt-5" in config.model.lower()
else NOT_GIVEN,
else omit,
store=False,
stream=False,
)
Expand Down Expand Up @@ -463,9 +463,9 @@ async def _completion_stream( # noqa: C901, PLR0912, PLR0915
try:
async with self._client.responses.stream(
model=config.model,
instructions=instructions or NOT_GIVEN,
instructions=instructions or omit,
input=input_context,
temperature=unwrap_missing(config.temperature, default=NOT_GIVEN),
temperature=unwrap_missing(config.temperature, default=omit),
tool_choice=_tool_choice(tools),
tools=_tools_as_tool_params(tools.specifications),
text=_text_output(output, verbosity=config.verbosity),
Expand All @@ -475,18 +475,18 @@ async def _completion_stream( # noqa: C901, PLR0912, PLR0915
summary=config.reasoning_summary,
)
if isinstance(config.reasoning, str)
else NOT_GIVEN
else omit
),
parallel_tool_calls=config.parallel_tool_calls,
max_output_tokens=config.max_output_tokens,
service_tier=config.service_tier,
truncation=config.truncation,
safety_identifier=config.safety_identifier or NOT_GIVEN,
prompt_cache_key=cache_key or NOT_GIVEN,
safety_identifier=config.safety_identifier or omit,
prompt_cache_key=cache_key or omit,
include=["reasoning.encrypted_content"]
# for gpt-5 model family we need to request encrypted reasoning
if "gpt-5" in config.model.lower()
else NOT_GIVEN,
else omit,
store=False,
) as stream:
async for event in stream:
Expand Down Expand Up @@ -660,9 +660,9 @@ def _text_output( # noqa: PLR0911
/,
*,
verbosity: Literal["low", "medium", "high"] | Missing = MISSING,
) -> ResponseTextConfigParam | NotGiven:
) -> ResponseTextConfigParam | Omit:
if output == "auto":
return NOT_GIVEN
return omit

if output == "text":
if verbosity is MISSING:
Expand Down Expand Up @@ -713,7 +713,7 @@ def _text_output( # noqa: PLR0911
"verbosity": cast(Literal["low", "medium", "high"], verbosity),
}

return NOT_GIVEN
return omit


def _tool_choice(
Expand Down
6 changes: 3 additions & 3 deletions src/draive/openai/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import cast

from haiway import MISSING, Missing
from openai import NOT_GIVEN, NotGiven
from openai import Omit, omit

__all__ = ("unwrap_missing",)


def unwrap_missing[Value](
value: Value | Missing,
/,
) -> Value | NotGiven:
) -> Value | Omit:
if value is MISSING:
return NOT_GIVEN
return omit

else:
return cast(Value, value)
4 changes: 0 additions & 4 deletions src/draive/parameters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
ParameterSpecification,
ParametersSpecification,
ToolParametersSpecification,
validated_specification,
validated_tool_specification,
)
from draive.parameters.validation import ParameterValidator, ParameterVerification

Expand All @@ -21,6 +19,4 @@
"ParametersSpecification",
"ParametrizedFunction",
"ToolParametersSpecification",
"validated_specification",
"validated_tool_specification",
)
26 changes: 26 additions & 0 deletions src/draive/parameters/coding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import json
from datetime import date, datetime, time
from typing import (
Any,
)
from uuid import UUID

__all__ = ("ParametersJSONEncoder",)


class ParametersJSONEncoder(json.JSONEncoder):
def default(self, o: object) -> Any:
if isinstance(o, UUID):
return o.hex

elif isinstance(o, datetime):
return o.isoformat()

elif isinstance(o, time):
return o.isoformat()

elif isinstance(o, date):
return o.isoformat()

else:
return json.JSONEncoder.default(self, o)
27 changes: 12 additions & 15 deletions src/draive/parameters/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
from inspect import _empty as INSPECT_EMPTY # pyright: ignore[reportPrivateUsage]
from inspect import signature
from types import EllipsisType
from typing import Any, ClassVar, cast, final, get_type_hints, overload
from typing import Any, ClassVar, Required, cast, final, get_type_hints, overload

from haiway import MISSING, DefaultValue, Missing, ValidationContext
from haiway.state import AttributeAnnotation
from haiway.state.attributes import resolve_attribute_annotation
from haiway.state.validation import Validator
from haiway import MISSING, AttributeAnnotation, DefaultValue, Missing, ValidationContext, Validator
from haiway.attributes.annotations import resolve_attribute
from haiway.utils import mimic_function

from draive.parameters.parameter import Parameter
Expand Down Expand Up @@ -226,20 +224,20 @@ def validate_arguments(
if self._variadic_keyword_parameters is None:
for parameter in self._parameters.values():
with ValidationContext.scope(f".{parameter.name}"):
validated[parameter.name] = parameter.validated(
validated[parameter.name] = parameter.validate(
parameter.find(kwargs),
)

else:
for parameter in self._parameters.values():
with ValidationContext.scope(f".{parameter.name}"):
validated[parameter.name] = parameter.validated(
validated[parameter.name] = parameter.validate(
parameter.pick(kwargs),
)

for key, value in kwargs.items():
with ValidationContext.scope(f".{key}"):
validated[key] = self._variadic_keyword_parameters.validated(
validated[key] = self._variadic_keyword_parameters.validate(
value,
)

Expand Down Expand Up @@ -286,11 +284,10 @@ def _resolve_argument(
parameter.name,
)

attribute: AttributeAnnotation = resolve_attribute_annotation(
attribute: AttributeAnnotation = resolve_attribute(
type_hint,
module=module,
type_parameters={},
self_annotation=None,
resolved_parameters={},
recursion_guard={},
)

Expand All @@ -310,9 +307,9 @@ def _resolve_argument(
verifier=argument.verifier,
converter=MISSING,
specification=argument.specification,
required=attribute.required
and argument.default is MISSING
and argument.default_factory is MISSING,
required=argument.default is MISSING
and argument.default_factory is MISSING
and Required in attribute.annotations,
)

case DefaultValue() as default: # pyright: ignore[reportUnknownVariableType]
Expand Down Expand Up @@ -340,5 +337,5 @@ def _resolve_argument(
verifier=MISSING,
converter=MISSING,
specification=MISSING,
required=attribute.required and value is INSPECT_EMPTY,
required=value is INSPECT_EMPTY and Required in attribute.annotations,
)
Loading