Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "uv_build"
[project]
name = "draive"
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
version = "0.84.5"
version = "0.84.6"
readme = "README.md"
maintainers = [
{ name = "Kacper Kaliński", email = "kacper.kalinski@miquido.com" },
Expand Down
8 changes: 4 additions & 4 deletions src/draive/anthropic/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ async def _completion( # noqa: C901, PLR0912
"model.provider": self._provider,
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool["name"] for tool in tools.specifications],
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
Expand Down Expand Up @@ -672,7 +672,7 @@ def _tools_as_tool_params(
tool_params: list[ToolParam] = []
for tool in tools:
input_schema: dict[str, Any]
if parameters := tool["parameters"]:
if parameters := tool.parameters:
input_schema = cast(dict[str, Any], parameters)

else:
Expand All @@ -685,8 +685,8 @@ def _tools_as_tool_params(

tool_params.append(
{
"name": tool["name"],
"description": tool["description"] or "",
"name": tool.name,
"description": tool.description or "",
"input_schema": input_schema,
}
)
Expand Down
8 changes: 4 additions & 4 deletions src/draive/bedrock/converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def _completion(
"model.provider": "bedrock",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool["name"] for tool in tools.specifications],
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
Expand Down Expand Up @@ -430,9 +430,9 @@ def _convert_content( # noqa: C901

def _convert_tool(tool: ModelToolSpecification) -> ChatTool:
return {
"name": tool["name"],
"description": tool["description"] or "",
"inputSchema": {"json": tool["parameters"]},
"name": tool.name,
"description": tool.description or "",
"inputSchema": {"json": tool.parameters},
}


Expand Down
10 changes: 5 additions & 5 deletions src/draive/gemini/generating.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ async def _completion(
"model.provider": "gemini",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool["name"] for tool in tools.specifications],
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
Expand Down Expand Up @@ -279,7 +279,7 @@ async def _completion_stream(
"model.provider": "gemini",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool["name"] for tool in tools.specifications],
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
Expand Down Expand Up @@ -466,9 +466,9 @@ def _prepare_request_config( # noqa: C901, PLR0912, PLR0915
if tools.specifications:
functions = [
FunctionDeclarationDict(
name=tool["name"],
description=tool["description"],
parameters_json_schema=cast(SchemaDict, tool["parameters"]),
name=tool.name,
description=tool.description,
parameters_json_schema=cast(SchemaDict, tool.parameters),
)
for tool in tools.specifications or []
]
Expand Down
6 changes: 3 additions & 3 deletions src/draive/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@ async def list_tools() -> list[MCPTool]: # pyright: ignore[reportUnusedFunction
):
return [
MCPTool(
name=tool["name"],
description=tool["description"],
inputSchema=as_dict(tool["parameters"]) or {},
name=tool.name,
description=tool.description,
inputSchema=as_dict(tool.parameters) or {},
)
for tool in (
tool.specification for tool in toolbox.tools.values() if tool.available
Expand Down
8 changes: 4 additions & 4 deletions src/draive/mistral/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ async def _completion(
"model.provider": "mistral",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool["name"] for tool in tools.specifications],
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
Expand Down Expand Up @@ -571,9 +571,9 @@ def _tool_specification_as_tool(
return {
"type": "function",
"function": {
"name": tool["name"],
"description": tool["description"] or "",
"parameters": cast(dict[str, Any], tool["parameters"])
"name": tool.name,
"description": tool.description or "",
"parameters": cast(dict[str, Any], tool.parameters)
or {
"type": "object",
"properties": {},
Expand Down
1 change: 1 addition & 0 deletions src/draive/models/tools/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
"description": description,
"parameters": parameters,
"additionalProperties": False,
"meta": meta,
},
)
self.handling: ModelToolHandling
Expand Down
2 changes: 1 addition & 1 deletion src/draive/models/tools/toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def available_tools_declaration(
tools_selection = "required"

else: # ModelToolSpecification
tools_selection = tool_suggestion["name"]
tools_selection = tool_suggestion.name

return ModelToolsDeclaration(
specifications=available_tools,
Expand Down
10 changes: 2 additions & 8 deletions src/draive/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Literal,
Protocol,
Self,
TypedDict,
final,
overload,
runtime_checkable,
Expand Down Expand Up @@ -214,16 +213,11 @@ def __init__(


@final
class ModelToolFunctionSpecification(TypedDict, total=True):
"""JSON Schema-like function specification exposed to the model.

Keys mirror OpenAI/JSON Schema conventions for function tools.
"""

class ModelToolFunctionSpecification(State):
name: str
description: str | None
parameters: ToolParametersSpecification | None
additionalProperties: Literal[False]
meta: Meta


ModelToolSpecification = ModelToolFunctionSpecification
Expand Down
10 changes: 5 additions & 5 deletions src/draive/ollama/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ async def _completion(
"model.provider": "ollama",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool["name"] for tool in tools.specifications],
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
Expand Down Expand Up @@ -344,11 +344,11 @@ def _tool_specification_as_tool(
return Tool(
type="function",
function=Tool.Function(
name=tool["name"],
description=tool["description"],
name=tool.name,
description=tool.description,
parameters=(
cast(Tool.Function.Parameters, tool["parameters"]) # type: ignore[arg-type]
if tool["parameters"]
cast(Tool.Function.Parameters, tool.parameters) # type: ignore[arg-type]
if tool.parameters
else {
"type": "object",
"properties": {},
Expand Down
8 changes: 4 additions & 4 deletions src/draive/openai/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def open_session() -> ModelSession: # noqa: C901, PLR0915
"model.input_audio_noise_reduction": config.input_audio_noise_reduction,
"model.vad": str(config.vad),
"model.voice": config.voice,
"model.tools": [tool["name"] for tool in tools.specifications],
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": f"{tools.selection}",
"model.output": f"{output}",
}
Expand Down Expand Up @@ -767,9 +767,9 @@ def _prepare_session_config(
without_missing(
{
"type": "function",
"name": tool["name"],
"description": tool.get("description", MISSING),
"parameters": tool.get("parameters", MISSING),
"name": tool.name,
"description": tool.description or MISSING,
"parameters": tool.parameters or MISSING,
},
)
for tool in tools.specifications
Expand Down
15 changes: 9 additions & 6 deletions src/draive/openai/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def _completion( # noqa: C901, PLR0912, PLR0915
"model.provider": "openai",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool["name"] for tool in tools.specifications],
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
Expand Down Expand Up @@ -422,7 +422,7 @@ async def _completion_stream( # noqa: C901, PLR0912, PLR0915
"model.provider": "openai",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool["name"] for tool in tools.specifications],
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
Expand Down Expand Up @@ -727,15 +727,18 @@ def _tools_as_tool_params(
ToolParam,
FunctionToolParam(
type="function",
name=tool["name"],
description=tool["description"] or None,
parameters=cast(dict[str, object] | None, tool["parameters"])
name=tool.name,
description=tool.description or None,
parameters=cast(dict[str, object] | None, tool.parameters)
or {
"type": "object",
"properties": {},
"additionalProperties": False,
},
strict=True,
strict=tool.meta.get_bool(
"strict_parameters",
default=False,
),
),
)
for tool in tools
Expand Down
10 changes: 10 additions & 0 deletions src/draive/parameters/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,5 +194,15 @@ def _simplified_schema_property( # noqa: C901, PLR0912, PLR0911
case {"type": "object", "additionalProperties": True}:
return {}

case {"type": [*alternatives], "description": str() as description}:
return (
"|".join(alternatives) + f"({description})"
if description
else "|".join(alternatives)
)

case {"type": [*alternatives]}:
return "|".join(alternatives)

case other:
raise ValueError("Unsupported basic specification element: %s", other)
61 changes: 50 additions & 11 deletions src/draive/parameters/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
)


@final
class ParameterAlternativesSpecification(TypedDict, total=False):
type: Required[Sequence[Literal["string", "number", "integer", "boolean", "null"]]]
description: NotRequired[str]


@final
class ParameterNoneSpecification(TypedDict, total=False):
type: Required[Literal["null"]]
Expand Down Expand Up @@ -148,7 +154,8 @@ class ParameterAnyObjectSpecification(TypedDict, total=False):


type ParameterSpecification = (
ParameterUnionSpecification
ParameterAlternativesSpecification
| ParameterUnionSpecification
| ParameterNoneSpecification
| ParameterStringEnumSpecification
| ParameterStringSpecification
Expand Down Expand Up @@ -462,22 +469,54 @@ def _prepare_specification_of_union(
/,
description: str | None,
) -> ParameterSpecification:
compressed_alternatives: list[Literal["string", "number", "integer", "boolean", "null"]] = []
alternatives: list[ParameterSpecification] = []
for argument in annotation.arguments:
specification: ParameterSpecification = parameter_specification(
cast(AttributeAnnotation, argument),
description=None,
)
alternatives.append(specification)
match specification:
case {"type": "null", **tail} if not tail:
compressed_alternatives.append("null")

case {"type": "string", **tail} if not tail:
compressed_alternatives.append("string")

case {"type": "number", **tail} if not tail:
compressed_alternatives.append("number")

case {"type": "integer", **tail} if not tail:
compressed_alternatives.append("integer")

case {"type": "boolean", **tail} if not tail:
compressed_alternatives.append("boolean")

case _:
pass # skip - type is more complex and can't be compressed

if description := description:
if len(compressed_alternatives) == len(alternatives):
# prefer comperessed when equivalent representation is available
return ParameterAlternativesSpecification(
type=compressed_alternatives,
description=description,
)

return {
"oneOf": [
parameter_specification(cast(AttributeAnnotation, argument), description=None)
for argument in annotation.arguments
],
"oneOf": alternatives,
"description": description,
}

else:
return {
"oneOf": [
parameter_specification(cast(AttributeAnnotation, argument), description=None)
for argument in annotation.arguments
],
}
if len(compressed_alternatives) == len(alternatives):
# prefer comperessed when equivalent representation is available
return ParameterAlternativesSpecification(
type=compressed_alternatives,
)

return {"oneOf": alternatives}


def _prepare_specification_of_bool(
Expand Down
10 changes: 5 additions & 5 deletions src/draive/vllm/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def _completion(
"model.provider": "vllm",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool["name"] for tool in tools.specifications],
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
Expand Down Expand Up @@ -273,7 +273,7 @@ async def _completion_stream( # noqa: C901, PLR0912, PLR0915
"model.provider": "vllm",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool["name"] for tool in tools.specifications],
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
Expand Down Expand Up @@ -590,9 +590,9 @@ def _tool_specification_as_tool(tool: ModelToolSpecification) -> ChatCompletionT
return {
"type": "function",
"function": {
"name": tool["name"],
"description": tool["description"] or "",
"parameters": cast(dict[str, Any], tool["parameters"])
"name": tool.name,
"description": tool.description or "",
"parameters": cast(dict[str, Any], tool.parameters)
or {
"type": "object",
"properties": {},
Expand Down
Loading