Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/draive/anthropic/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,22 @@ async def _completion( # noqa: C901, PLR0912
attributes={
"model.provider": self._provider,
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
"model.thinking_budget": config.thinking_budget,
"model.stop_sequences": config.stop_sequences,
"model.output": str(output),
"model.max_output_tokens": config.max_output_tokens,
"model.streaming": False,
"model.tools.count": len(tools.specifications),
"model.tools.selection": tools.selection,
"model.stream": False,
},
)
ctx.record(
ObservabilityLevel.DEBUG,
attributes={
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.context": [element.to_str() for element in context],
},
)

Expand Down
15 changes: 11 additions & 4 deletions src/draive/bedrock/converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,21 @@ async def _completion(
attributes={
"model.provider": "bedrock",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
"model.stop_sequences": config.stop_sequences,
"model.max_output_tokens": config.max_output_tokens,
"model.output": str(output),
"model.tools.count": len(tools.specifications),
"model.tools.selection": tools.selection,
"model.stream": False,
},
)
ctx.record(
ObservabilityLevel.DEBUG,
attributes={
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.context": [element.to_str() for element in context],
},
)

Expand Down
32 changes: 22 additions & 10 deletions src/draive/gemini/generating.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,21 @@ async def _completion(
attributes={
"model.provider": "gemini",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
"model.output": str(output),
"model.max_output_tokens": config.max_output_tokens,
"model.thinking_budget": config.thinking_budget,
"model.streaming": False,
"model.tools.count": len(tools.specifications),
"model.tools.selection": tools.selection,
"model.stream": False,
},
)
ctx.record(
ObservabilityLevel.DEBUG,
attributes={
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.context": [element.to_str() for element in context],
},
)

Expand Down Expand Up @@ -278,15 +284,21 @@ async def _completion_stream(
attributes={
"model.provider": "gemini",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
"model.output": str(output),
"model.max_output_tokens": config.max_output_tokens,
"model.thinking_budget": config.thinking_budget,
"model.streaming": True,
"model.tools.count": len(tools.specifications),
"model.tools.selection": tools.selection,
"model.stream": True,
},
)
ctx.record(
ObservabilityLevel.DEBUG,
attributes={
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.context": [element.to_str() for element in context],
},
)

Expand Down
36 changes: 32 additions & 4 deletions src/draive/mistral/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,27 @@ async def _completion_stream( # noqa: C901, PLR0912
prefill: Multimodal | None,
**extra: Any,
) -> AsyncGenerator[ModelStreamOutput]:
ctx.record(
ObservabilityLevel.INFO,
attributes={
"model.provider": "mistral",
"model.name": config.model,
"model.temperature": config.temperature,
"model.max_output_tokens": config.max_output_tokens,
"model.output": str(output),
"model.tools.count": len(tools.specifications),
"model.tools.selection": tools.selection,
"model.stream": True,
},
)
ctx.record(
ObservabilityLevel.DEBUG,
attributes={
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.context": [element.to_str() for element in context],
},
)
messages: list[MessagesTypedDict] = _build_messages(
context=context,
instructions=instructions,
Expand Down Expand Up @@ -291,13 +312,20 @@ async def _completion(
attributes={
"model.provider": "mistral",
"model.name": config.model,
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
"model.max_output_tokens": config.max_output_tokens,
"model.output": str(output),
"model.tools.count": len(tools.specifications),
"model.tools.selection": tools.selection,
"model.stream": False,
},
)
ctx.record(
ObservabilityLevel.DEBUG,
attributes={
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.context": [element.to_str() for element in context],
},
)

Expand Down
14 changes: 10 additions & 4 deletions src/draive/ollama/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,19 @@ async def _completion(
attributes={
"model.provider": "ollama",
"model.name": config.model,
"model.temperature": config.temperature,
"model.output": str(output),
"model.tools.count": len(tools.specifications),
"model.tools.selection": tools.selection,
"model.stream": False,
},
)
ctx.record(
ObservabilityLevel.DEBUG,
attributes={
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
"model.output": str(output),
"model.streaming": False,
},
)

Expand Down
12 changes: 10 additions & 2 deletions src/draive/stages/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,7 @@ def result_evaluation(
| PreparedEvaluator[MultimodalContent],
/,
*,
raises: bool = False,
meta: Meta | MetaValues | None = None,
) -> Self:
"""
Expand All @@ -945,6 +946,9 @@ def result_evaluation(
evaluator : PreparedEvaluatorScenario[MultimodalContent]
| PreparedEvaluator[MultimodalContent]
The evaluator or scenario evaluator to use for evaluation.
raises: bool = False
Determines whether to raise ``StageException`` when the evaluation fails.
When ``False``, the stage returns the input state unchanged on failure.
meta: Meta | MetaValues | None = None
Additional stage metadata including tags, description etc.

Expand All @@ -967,7 +971,7 @@ async def stage(
state.result
)

if evaluation_result.passed:
if evaluation_result.passed or not raises:
return state # passed, or non-raising mode tolerates the failure — keep going

performance: float = evaluation_result.performance
Expand All @@ -992,6 +996,7 @@ def context_evaluation(
evaluator: PreparedEvaluatorScenario[ModelContext] | PreparedEvaluator[ModelContext],
/,
*,
raises: bool = False,
meta: Meta | MetaValues | None = None,
) -> Self:
"""
Expand All @@ -1004,6 +1009,9 @@ def context_evaluation(
----------
evaluator : PreparedEvaluatorScenario[Value] | PreparedEvaluator[Value]
The evaluator or scenario evaluator to use for evaluation.
raises: bool = False
Determines whether to raise ``StageException`` when the evaluation fails.
When ``False``, the stage returns the input state unchanged on failure.
meta: Meta | MetaValues | None = None
Additional stage metadata including tags, description etc.

Expand All @@ -1026,7 +1034,7 @@ async def stage(
state.context
)

if evaluation_result.passed:
if evaluation_result.passed or not raises:
return state # passed, or non-raising mode tolerates the failure — keep going

performance: float = evaluation_result.performance
Expand Down
30 changes: 22 additions & 8 deletions src/draive/vllm/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,20 @@ async def _completion(
attributes={
"model.provider": "vllm",
"model.name": config.model,
"model.temperature": config.temperature,
"model.max_output_tokens": config.max_output_tokens,
"model.output": str(output),
"model.tools.count": len(tools.specifications),
"model.tools.selection": tools.selection,
"model.stream": False,
},
)
ctx.record(
ObservabilityLevel.DEBUG,
attributes={
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
"model.output": str(output),
"model.streaming": False,
},
)

Expand Down Expand Up @@ -272,13 +279,20 @@ async def _completion_stream( # noqa: C901, PLR0912, PLR0915
attributes={
"model.provider": "vllm",
"model.name": config.model,
"model.temperature": config.temperature,
"model.max_output_tokens": config.max_output_tokens,
"model.output": str(output),
"model.tools.count": len(tools.specifications),
"model.tools.selection": tools.selection,
"model.stream": True,
},
)
ctx.record(
ObservabilityLevel.DEBUG,
attributes={
"model.instructions": instructions,
"model.tools": [tool.name for tool in tools.specifications],
"model.tool_selection": tools.selection,
"model.context": [element.to_str() for element in context],
"model.temperature": config.temperature,
"model.output": str(output),
"model.streaming": True,
},
)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,10 +970,10 @@ async def __call__(self, result: MultimodalContent) -> EvaluatorResult:

# Test failing evaluation
failing_evaluator = MockEvaluator(should_pass=False)
fail_eval_stage = Stage.result_evaluation(failing_evaluator)
fail_raising_eval_stage = Stage.result_evaluation(failing_evaluator, raises=True)

with raises(StageException) as exc_info:
await fail_eval_stage(state=initial_state)
await fail_raising_eval_stage(state=initial_state)

assert "Result evaluation failed" in str(exc_info.value)
assert exc_info.value.meta["evaluation_performance"] == 60.0 # (0.3/0.5) * 100
Expand Down