Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "uv_build"
[project]
name = "draive"
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
version = "0.84.6"
version = "0.84.7"
readme = "README.md"
maintainers = [
{ name = "Kacper Kaliński", email = "kacper.kalinski@miquido.com" },
Expand Down
4 changes: 2 additions & 2 deletions src/draive/mcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from mcp.types import TextContent as MCPTextContent
from pydantic import AnyUrl

from draive.models import FunctionTool, Tool, ToolError, ToolsProvider
from draive.models import FunctionTool, ModelToolSpecification, Tool, ToolError, ToolsProvider
from draive.multimodal import ArtifactContent, MultimodalContent, TextContent
from draive.parameters import DataModel, validated_tool_specification
from draive.resources import ResourceContent, ResourceReference, ResourcesRepository
Expand Down Expand Up @@ -502,7 +502,7 @@ async def remote_call(**arguments: Any) -> MultimodalContent:

def _available(
tools_turn: int,
meta: Meta,
specification: ModelToolSpecification,
) -> bool:
return True

Expand Down
25 changes: 15 additions & 10 deletions src/draive/mistral/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,35 +203,40 @@ async def _completion_stream( # noqa: C901, PLR0912

else:
for tool_call in tool_calls:
assert tool_call.index, "Can't identify function call without index" # nosec: B101
if tool_call.index is None:
raise ModelOutputFailed(
provider="mistral",
model=config.model,
reason="Invalid completion: missing tool call index",
)

index: int = tool_call.index

# "null" is a default value...
if tool_call.id and tool_call.id != "null":
accumulated_tool_calls[tool_call.index].id = tool_call.id
accumulated_tool_calls[index].id = tool_call.id

if tool_call.function.name:
accumulated_tool_calls[
tool_call.index
].function.name += tool_call.function.name
accumulated_tool_calls[index].function.name += tool_call.function.name

if isinstance(tool_call.function.arguments, str):
assert isinstance( # nosec: B101
accumulated_tool_calls[tool_call.index].function.arguments,
accumulated_tool_calls[index].function.arguments,
str,
)
accumulated_tool_calls[ # pyright: ignore[reportOperatorIssue]
tool_call.index
index
].function.arguments += tool_call.function.arguments

else:
assert isinstance( # nosec: B101
accumulated_tool_calls[tool_call.index].function.arguments,
accumulated_tool_calls[index].function.arguments,
dict,
)
accumulated_tool_calls[tool_call.index].function.arguments = {
accumulated_tool_calls[index].function.arguments = {
**cast(
dict,
accumulated_tool_calls[tool_call.index].function.arguments,
accumulated_tool_calls[index].function.arguments,
),
**tool_call.function.arguments,
}
Expand Down
2 changes: 1 addition & 1 deletion src/draive/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from draive.models.tools import (
FunctionTool,
ModelToolSpecification,
Tool,
ToolAvailabilityChecking,
Toolbox,
Expand Down Expand Up @@ -66,7 +67,6 @@
ModelToolRequest,
ModelToolResponse,
ModelToolsDeclaration,
ModelToolSpecification,
ModelToolsSelection,
)

Expand Down
55 changes: 55 additions & 0 deletions src/draive/models/generative.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from asyncio import ALL_COMPLETED, FIRST_COMPLETED, Task, wait
from collections.abc import (
AsyncGenerator,
Collection,
Generator,
MutableSequence,
Sequence,
Expand Down Expand Up @@ -314,6 +315,11 @@ async def _loop(
output: ModelOutputSelection,
**extra: Any,
) -> ModelOutput:
"""Run the non-streaming tool loop and return the final output.

Executes repeated ``generating`` calls, dispatches tool requests, and merges tool
responses back into the mutable ``context`` until a terminal output is produced.
"""
ctx.log_debug("GenerativeModel loop started...")
tools_turn: int = 0
result_extension: Sequence[ModelReasoning | MultimodalContent] = ()
Expand Down Expand Up @@ -409,6 +415,12 @@ async def _streaming_loop( # noqa: C901, PLR0912
output: ModelOutputSelection,
**extra: Any,
) -> AsyncGenerator[ModelStreamOutput]:
"""Stream chunks while coordinating tool execution between turns.

Accumulates streamed model output, forwards tool requests to the ``toolbox`` as they
appear, yields intermediate content, and updates ``context`` until the session
completes.
"""
ctx.log_debug("GenerativeModel streaming loop started...")
tools_turn: int = 0
result_extension: Sequence[ModelReasoning | MultimodalContent] = ()
Expand Down Expand Up @@ -690,6 +702,31 @@ async def session(
session_preparing: ModelSessionPreparing


def _matches_modalities(
    part: MultimodalContentPart,
    *,
    allowed: Collection[Literal["text", "image", "audio", "video"]],
) -> bool:
    """Tell whether ``part`` belongs to one of the ``allowed`` modalities.

    A ``TextContent`` part matches when ``"text"`` is allowed. A ``ResourceContent``
    or ``ResourceReference`` part matches when its MIME type starts with an allowed
    media modality name (``"image"``, ``"audio"`` or ``"video"``); a missing MIME
    type is treated as an empty string. Everything else, as well as an empty
    ``allowed`` collection, does not match.
    """
    # nothing allowed means nothing can match
    if not allowed:
        return False

    if isinstance(part, TextContent) and "text" in allowed:
        return True

    # only resources carry a MIME type to compare against media modalities
    if not isinstance(part, (ResourceContent, ResourceReference)):
        return False

    resource_mime: str = part.mime_type or ""
    media_prefixes: tuple[str, ...] = tuple(
        modality for modality in ("image", "audio", "video") if modality in allowed
    )
    # str.startswith with an empty tuple is False, so no allowed media never matches
    return resource_mime.startswith(media_prefixes)


def _decoded( # noqa: PLR0911
result: ModelOutput,
/,
Expand Down Expand Up @@ -726,6 +763,14 @@ def _decoded( # noqa: PLR0911
case "video":
return result.updated(blocks=(MultimodalContent.of(*result.content.video()),))

case selection if isinstance(selection, Collection) and not isinstance(selection, str):
selected_parts: tuple[MultimodalContentPart, ...] = tuple(
part
for part in result.content.parts
if _matches_modalities(part, allowed=set(selection))
)
return result.updated(blocks=(MultimodalContent.of(*selected_parts),))

case "json":
return result.updated(
blocks=(
Expand Down Expand Up @@ -811,6 +856,16 @@ async def _decoded_stream( # noqa: C901, PLR0912

# skip non video output

case selection if isinstance(selection, Collection) and not isinstance(selection, str):
async for part in stream:
if isinstance(part, MultimodalContentPart) and _matches_modalities(
part,
allowed=set(selection),
):
yield part

# skip non matching output

case "json":
accumulator: list[str] = []
async for part in stream:
Expand Down
149 changes: 115 additions & 34 deletions src/draive/models/instructions/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def volatile(
----------
instructions : dict[str, str]
Mapping of instruction name to its content.

Returns
-------
InstructionsRepository
Repository facade configured with volatile in-memory storage.
"""
volatile_storage: InstructionsVolatileStorage = InstructionsVolatileStorage(
_declarations={
Expand Down Expand Up @@ -108,6 +113,11 @@ def file(
----------
path : Path | str
Path to the storage file. The file is created on first write.

Returns
-------
InstructionsRepository
Repository facade configured with file-backed storage.
"""
file_storage: InstructionsFileStorage = InstructionsFileStorage(path=path)

Expand Down Expand Up @@ -137,7 +147,18 @@ async def available_instructions(
self,
**extra: Any,
) -> Sequence[InstructionsDeclaration]:
"""List available instruction declarations from the backend."""
"""List available instruction declarations from the backend.

Parameters
----------
**extra : Any
Extra keyword arguments forwarded to the underlying listing callable.

Returns
-------
Sequence[InstructionsDeclaration]
Declarations available in the configured storage backend.
"""
return await self.listing(
**extra,
)
Expand Down Expand Up @@ -198,6 +219,12 @@ async def resolve(
-------
str
Resolved instructions content.

Raises
------
InstructionsMissing
Raised when the referenced instructions cannot be loaded and no default
fallback is provided.
"""
if instructions is None:
return (
Expand Down Expand Up @@ -241,6 +268,25 @@ async def load(
/,
**extra: Any,
) -> ModelInstructions:
"""Load the raw instructions content from the backend.

Parameters
----------
instructions : Instructions
Instructions reference identifying what to load.
**extra : Any
Extra keyword arguments forwarded to the loading callable.

Returns
-------
str
Loaded instructions content.

Raises
------
InstructionsMissing
Raised when the repository backend has no content for the reference.
"""
loaded_instructions: str | None = await self.loading(
instructions.name,
meta=instructions.meta,
Expand Down Expand Up @@ -281,7 +327,22 @@ async def define(
content: str,
**extra: Any,
) -> None:
"""Create or update an instructions template in the backend."""
"""Create or update an instructions template in the backend.

Parameters
----------
instructions : InstructionsDeclaration | str
Template declaration or name describing what to define.
content : str
Template body that should be stored.
**extra : Any
Extra keyword arguments forwarded to the defining callable.

Returns
-------
None
This method performs a side effect on the configured backend.
"""
declaration: InstructionsDeclaration
if isinstance(instructions, str):
declaration = InstructionsDeclaration.of(instructions)
Expand Down Expand Up @@ -319,7 +380,20 @@ async def remove(
/,
**extra: Any,
) -> None:
"""Remove an instructions template from the backend."""
"""Remove an instructions template from the backend.

Parameters
----------
instructions : InstructionsDeclaration
Template declaration identifying what to remove.
**extra : Any
Extra keyword arguments forwarded to the removing callable.

Returns
-------
None
This method performs a side effect on the configured backend.
"""
await self.removing(
name=instructions.name,
meta=instructions.meta,
Expand Down Expand Up @@ -462,37 +536,44 @@ async def _load_file(self) -> None:

declarations: MutableMapping[str, InstructionsDeclaration] = {}
contents: MutableMapping[str, str] = {}
match json.loads(file_contents):
case [*elements]:
for element in elements:
match element:
case {
"name": str() as name,
"arguments": [*arguments],
"content": str() as content,
"description": str() | None as description,
"meta": {**meta},
}:
declarations[name] = InstructionsDeclaration(
name=name,
arguments=[
InstructionsArgumentDeclaration.from_mapping(argument)
for argument in arguments
],
description=description,
meta=Meta.of(meta),
)
contents[name] = content

case _:
# skip with warning
ctx.log_warning(
"Invalid file instructions storage element, skipping..."
)

case _:
# empty storage with warning
ctx.log_warning("Invalid file instructions storage, using empty storage...")
try:
match json.loads(file_contents):
case [*elements]:
for element in elements:
match element:
case {
"name": str() as name,
"arguments": [*arguments],
"content": str() as content,
"description": str() | None as description,
"meta": {**meta},
}:
declarations[name] = InstructionsDeclaration(
name=name,
arguments=[
InstructionsArgumentDeclaration.from_mapping(argument)
for argument in arguments
],
description=description,
meta=Meta.of(meta),
)
contents[name] = content

case _:
# skip with warning
ctx.log_warning(
"Invalid file instructions storage element, skipping..."
)

case _:
# empty storage with error
ctx.log_error("Invalid file instructions storage, using empty storage...")

except Exception as exc:
ctx.log_error(
"Invalid file instructions storage, using empty storage...",
exception=exc,
)

object.__setattr__(
self,
Expand Down
Loading