Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 175 additions & 30 deletions src/draive/stages/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
LMMInput,
LMMInstruction,
LMMOutputSelection,
LMMToolRequest,
LMMToolRequests,
LMMToolResponse,
LMMToolResponses,
)
from draive.multimodal import Multimodal, MultimodalContent
Expand All @@ -43,8 +45,8 @@
StageRouting,
StageState,
)
from draive.tools import Tool, Toolbox
from draive.utils.processing import Processing
from draive.tools import FunctionTool, Tool, Toolbox
from draive.utils import Processing

__all__ = (
"Stage",
Expand Down Expand Up @@ -138,7 +140,8 @@ def predefined(
This Stage adds the specified elements to the LMM context and
sets the result to the last provided completion value. It's useful for
injecting static content into a processing pipeline.
Note that context must always end with LMMCompletion as the result.
Content of a last LMMCompletion in the context will be used as stage result.
Note that context should always end with LMMCompletion as the result.

Parameters
----------
Expand Down Expand Up @@ -174,22 +177,24 @@ def predefined(
context = [LMMInput.of(MultimodalContent.of(content))]

for idx, value in enumerate(elements):
if idx % 2 == 0:
if isinstance(value, LMMContextElement):
context.append(value)

else:
context.append(LMMInput.of(MultimodalContent.of(value)))

elif isinstance(value, LMMContextElement):
if isinstance(value, LMMContextElement):
context.append(value)

elif idx % 2 == 0:
context.append(LMMInput.of(MultimodalContent.of(value)))

else:
context.append(LMMCompletion.of(MultimodalContent.of(value)))

context_extension: LMMContext = tuple(context)
assert isinstance(context_extension[-1], LMMCompletion) # nosec: B101
completion_result: MultimodalContent = context_extension[-1].content
completion_result: MultimodalContent | None = next(
(
item.content
for item in reversed(context_extension)
if isinstance(item, LMMCompletion)
),
None,
)

async def stage(
*,
Expand All @@ -198,10 +203,13 @@ async def stage(
with ctx.scope("stage.predefined"):
return state.updated(
context=(*state.context, *context_extension),
result=completion_result,
result=completion_result if completion_result is not None else state.result,
)

return cls(stage, meta=Meta.of(meta))
return cls(
stage,
meta=Meta.of(meta),
)

@classmethod
def memory_recall(
Expand Down Expand Up @@ -255,7 +263,10 @@ async def stage(
with ctx.scope("stage.memory_recall"):
return state.merged(await memory.recall())

return cls(stage, meta=Meta.of(meta))
return cls(
stage,
meta=Meta.of(meta),
)

@classmethod
def memory_remember(
Expand Down Expand Up @@ -292,7 +303,10 @@ async def stage(
await memory.remember(state)
return state

return cls(stage, meta=Meta.of(meta))
return cls(
stage,
meta=Meta.of(meta),
)

@classmethod
def completion(
Expand Down Expand Up @@ -376,7 +390,10 @@ async def stage(
result=result,
)

return cls(stage, meta=Meta.of(meta))
return cls(
stage,
meta=Meta.of(meta),
)

@classmethod
def prompting_completion(
Expand Down Expand Up @@ -461,7 +478,10 @@ async def stage(
result=result,
)

return cls(stage, meta=Meta.of(meta))
return cls(
stage,
meta=Meta.of(meta),
)

@classmethod
def loopback_completion(
Expand Down Expand Up @@ -515,15 +535,15 @@ async def stage(
return state

# Find the index of the last LMMInput in the context
last_input_idx: int = -1
last_input_idx: int = len(state.context)
for idx, element in enumerate(reversed(state.context)):
if isinstance(element, LMMInput):
last_input_idx = len(state.context) - idx - 1
break

else:
# using whole context as is
ctx.log_warning("loopback_completion could not find an LMMInput in the context")
return state

context, result = await _lmm_completion(
instruction=instruction,
Expand All @@ -540,7 +560,10 @@ async def stage(
result=result,
)

return cls(stage, meta=Meta.of(meta))
return cls(
stage,
meta=Meta.of(meta),
)

@classmethod
def transform_result(
Expand Down Expand Up @@ -576,7 +599,10 @@ async def stage(
with ctx.scope("stage.transform.result"):
return state.updated(result=await transformation(state.result))

return cls(stage, meta=Meta.of(meta))
return cls(
stage,
meta=Meta.of(meta),
)

@classmethod
def transform_context(
Expand Down Expand Up @@ -612,7 +638,10 @@ async def stage(
with ctx.scope("stage.transform.context"):
return state.updated(context=await transformation(state.context))

return cls(stage, meta=Meta.of(meta))
return cls(
stage,
meta=Meta.of(meta),
)

@classmethod
def trim_context(
Expand Down Expand Up @@ -667,7 +696,10 @@ async def stage(
) -> StageState:
return state.updated(context=state.context[index_slice])

return cls(stage, meta=Meta.of(meta))
return cls(
stage,
meta=Meta.of(meta),
)

@classmethod
def strip_context_tools(
Expand Down Expand Up @@ -701,7 +733,108 @@ async def stage(
),
)

return cls(stage, meta=Meta.of(meta))
return cls(
stage,
meta=Meta.of(meta),
)

@classmethod
def tool_call[**Args, Result](
cls,
tool: FunctionTool[Args, Result],
/,
*args: Args.args,
**kwargs: Args.kwargs,
) -> Self:
"""
Creates a Stage that executes a tool and adds its result to the context.

This Stage calls the provided tool with the given arguments and adds
the tool call to the LMM context as proper tool request/response pairs.
This makes the tool interaction visible in the context for subsequent stages.

Warning:
--------
Models usually require tool calls to be between regular input/completion messages.
You may need to manually adjust context afterwards to ensure proper contents.

Parameters
----------
tool : FunctionTool[Args, Result]
The tool to execute.
*args : Args.args
Positional arguments to pass to the tool.
**kwargs : Args.kwargs
Keyword arguments to pass to the tool.

Returns
-------
Self
A new Stage instance that executes the tool and adds its result to context.

Examples
--------
>>> @tool
... async def get_weather(location: str) -> str:
... return f"Weather in {location}: sunny"
...
>>> stage = Stage.tool_call(get_weather, location="New York")
"""
assert not args, "Positional arguments are not supported" # nosec: B101
direct_result: bool
match tool.handling:
case "auto":
direct_result = False

case "direct":
direct_result = True

async def stage(
*,
state: StageState,
) -> StageState:
with ctx.scope("stage.tool_call"):
# Create tool request representing the call
tool_request: LMMToolRequest = LMMToolRequest.of(
uuid4().hex,
tool=tool.name,
arguments=kwargs,
)

tool_response: LMMToolResponse
try:
# Execute the tool directly
result = await tool(*args, **kwargs)

# Create tool response with the result
tool_response = LMMToolResponse.of(
tool_request.identifier,
tool=tool.name,
content=MultimodalContent.of(tool.format_result(result)),
handling="direct_result" if direct_result else "result",
)

except Exception as exc:
tool_response = LMMToolResponse.of(
tool_request.identifier,
tool=tool.name,
content=MultimodalContent.of(tool.format_failure(exc)),
handling="direct_result" if direct_result else "error",
)

return state.updated(
context=(
*state.context,
LMMToolRequests.of((tool_request,)),
LMMToolResponses.of((tool_response,)),
),
result=tool_response.content if direct_result else state.result,
)

return cls(
stage,
meta=Meta.of({"tool_call": tool.name}),
)

@classmethod
def loop(
Expand Down Expand Up @@ -787,7 +920,10 @@ async def stage_loop(

iteration += 1

return cls(stage_loop, meta=Meta.of(meta))
return cls(
stage_loop,
meta=Meta.of(meta),
)

@classmethod
def sequence(
Expand Down Expand Up @@ -830,7 +966,10 @@ async def stage_sequence(

return current_state

return cls(stage_sequence, meta=Meta.of(meta))
return cls(
stage_sequence,
meta=Meta.of(meta),
)

@classmethod
def concurrent(
Expand Down Expand Up @@ -883,7 +1022,10 @@ async def concurrent_stage(
)
)

return cls(concurrent_stage, meta=Meta.of(meta))
return cls(
concurrent_stage,
meta=Meta.of(meta),
)

@classmethod
def router(
Expand Down Expand Up @@ -995,7 +1137,10 @@ async def router_stage(
f" - missing selection ({selection}) in available options"
)

return cls(router_stage, meta=Meta.of(meta))
return cls(
router_stage,
meta=Meta.of(meta),
)

__slots__ = (
"_execution",
Expand Down
Loading