Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "draive"
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
version = "0.71.1"
version = "0.71.2"
readme = "README.md"
maintainers = [
{ name = "Kacper Kaliński", email = "kacper.kalinski@miquido.com" },
Expand Down
126 changes: 126 additions & 0 deletions src/draive/stages/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
from haiway import Disposable, Disposables, State, cache, ctx, retry

from draive.commons import Meta, MetaValue, MetaValues
from draive.evaluation import (
EvaluatorResult,
PreparedEvaluator,
PreparedScenarioEvaluator,
ScenarioEvaluatorResult,
)
from draive.instructions import Instruction
from draive.lmm import (
LMM,
Expand Down Expand Up @@ -836,6 +842,126 @@ async def stage(
meta=Meta.of({"tool_call": tool.name}),
)

@classmethod
def result_evaluation(
    cls,
    evaluator: PreparedScenarioEvaluator[MultimodalContent]
    | PreparedEvaluator[MultimodalContent],
    /,
    *,
    meta: Meta | MetaValues | None = None,
) -> Self:
    """
    Build a Stage that checks the current stage result with an evaluator.

    When executed, the Stage passes ``state.result`` to the provided evaluator
    (or scenario evaluator). If the evaluation passes, the state is returned
    unchanged; otherwise a StageException is raised.

    Parameters
    ----------
    evaluator : PreparedScenarioEvaluator[MultimodalContent]
        | PreparedEvaluator[MultimodalContent]
        Prepared evaluator or scenario evaluator awaited with the current result.
    meta: Meta | MetaValues | None = None
        Additional stage metadata including tags, description etc.

    Returns
    -------
    Self
        A new Stage instance performing the result evaluation.

    Raises
    ------
    StageException
        Raised by the created stage when the evaluation does not pass. The
        exception carries the unchanged state and its meta contains
        "evaluation_score" and "evaluation_report" entries.

    Examples
    --------
    >>> stage = Stage.result_evaluation(evaluator)
    """

    async def stage(
        *,
        state: StageState,
    ) -> StageState:
        with ctx.scope("stage.result_evaluation"):
            outcome: ScenarioEvaluatorResult | EvaluatorResult = await evaluator(
                state.result,
            )

            if not outcome.passed:
                # gather failure details first - the report includes details
                # only when running with __debug__ enabled
                score: float = outcome.relative_score
                failure_report: str = outcome.report(include_details=__debug__)
                raise StageException(
                    f"Result evaluation failed with relative score: {score *100:.2f}%",
                    state=state,
                    meta={
                        "evaluation_score": score,
                        "evaluation_report": failure_report,
                    },
                )

            # evaluation passed, state goes through untouched
            return state

    stage_meta: Meta = Meta.of(meta)
    return cls(stage, meta=stage_meta)

@classmethod
def context_evaluation(
    cls,
    evaluator: PreparedScenarioEvaluator[LMMContext] | PreparedEvaluator[LMMContext],
    /,
    *,
    meta: Meta | MetaValues | None = None,
) -> Self:
    """
    Build a Stage that checks the current LMM context with an evaluator.

    When executed, the Stage passes ``state.context`` to the provided evaluator
    (or scenario evaluator). If the evaluation passes, the state is returned
    unchanged; otherwise a StageException is raised.

    Parameters
    ----------
    evaluator : PreparedScenarioEvaluator[LMMContext] | PreparedEvaluator[LMMContext]
        Prepared evaluator or scenario evaluator awaited with the current context.
    meta: Meta | MetaValues | None = None
        Additional stage metadata including tags, description etc.

    Returns
    -------
    Self
        A new Stage instance performing the context evaluation.

    Raises
    ------
    StageException
        Raised by the created stage when the evaluation does not pass. The
        exception carries the unchanged state and its meta contains
        "evaluation_score" and "evaluation_report" entries.

    Examples
    --------
    >>> stage = Stage.context_evaluation(evaluator)
    """

    async def stage(
        *,
        state: StageState,
    ) -> StageState:
        with ctx.scope("stage.context_evaluation"):
            outcome: ScenarioEvaluatorResult | EvaluatorResult = await evaluator(
                state.context,
            )

            if not outcome.passed:
                # gather failure details first - the report includes details
                # only when running with __debug__ enabled
                score: float = outcome.relative_score
                failure_report: str = outcome.report(include_details=__debug__)
                raise StageException(
                    f"Context evaluation failed with relative score: {score *100:.2f}%",
                    state=state,
                    meta={
                        "evaluation_score": score,
                        "evaluation_report": failure_report,
                    },
                )

            # evaluation passed, state goes through untouched
            return state

    stage_meta: Meta = Meta.of(meta)
    return cls(stage, meta=stage_meta)

@classmethod
def loop(
cls,
Expand Down
8 changes: 7 additions & 1 deletion src/draive/stages/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from haiway import MissingState, State

from draive.commons import Meta
from draive.commons import Meta, MetaValues
from draive.lmm import LMMCompletion, LMMContext
from draive.multimodal import Multimodal, MultimodalContent
from draive.parameters import DataModel
Expand Down Expand Up @@ -282,6 +282,7 @@ def __init__(
self,
*args: object,
state: StageState,
meta: Meta | MetaValues | None = None,
) -> None:
"""
Initialize a new StageException.
Expand All @@ -292,9 +293,14 @@ def __init__(
Exception arguments passed to the parent Exception class.
state : StageState
The stage state at the time the exception occurred.

meta : Meta | MetaValues | None = None
Additional exception metadata

"""
super().__init__(*args)
self.state: StageState = state
self.meta: Meta = Meta.of(meta)


@runtime_checkable
Expand Down