Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "draive"
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
version = "0.75.3"
version = "0.75.4"
readme = "README.md"
maintainers = [
{ name = "Kacper Kaliński", email = "kacper.kalinski@miquido.com" },
Expand Down
35 changes: 20 additions & 15 deletions src/draive/helpers/instruction_refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ async def _generate_strategy_metadata[
</format>
""" # noqa: E501

response = await Stage.completion(
response: MultimodalContent = await Stage.completion(
f"<failure_analysis>\n{failure_report}\n</failure_analysis>",
instruction=strategy_prompt,
).execute()
Expand All @@ -501,18 +501,18 @@ async def _generate_strategy_metadata[
"strategy",
content=response,
):
name_elem: MultimodalTagElement | None = MultimodalTagElement.parse_first(
name_element: MultimodalTagElement | None = MultimodalTagElement.parse_first(
"name",
content=strategy_element.content,
)
approach_elem: MultimodalTagElement | None = MultimodalTagElement.parse_first(
approach_element: MultimodalTagElement | None = MultimodalTagElement.parse_first(
"approach",
content=strategy_element.content,
)

if name_elem and approach_elem:
strategy_name: str = name_elem.content.to_str().strip()
strategy_approach: str = approach_elem.content.to_str().strip()
if name_element and approach_element:
strategy_name: str = name_element.content.to_str().strip()
strategy_approach: str = approach_element.content.to_str().strip()
strategies.append((strategy_name, strategy_approach))

return strategies
Expand All @@ -530,7 +530,7 @@ async def _generate_instruction_content[
guidelines: str | None,
) -> str:
# Analyze failures
failure_report = evaluation_result.report(
failure_report: str = evaluation_result.report(
include_passed=False,
include_details=True,
)
Expand Down Expand Up @@ -559,12 +559,17 @@ async def _generate_instruction_content[
</format>
""" # noqa: E501

response = await Stage.completion(
ctx.log_info(f"Generating updated instruction using {strategy_name}:\n\n{strategy_approach}")
response: MultimodalContent = await Stage.completion(
f"<failure_analysis>\n{failure_report}\n</failure_analysis>",
instruction=refinement_prompt,
).execute()

return response.to_str().strip()
updated_instruction: str = response.to_str().strip()

ctx.log_info(f"Prepared updated instruction using {strategy_name}:\n\n{updated_instruction}")

return updated_instruction


async def _generate_refined_instructions[
Expand All @@ -577,7 +582,7 @@ async def _generate_refined_instructions[
guidelines: str | None,
) -> Sequence[tuple[str, Instruction]]:
# Step 1: Generate strategy metadata
strategy_metadata = await _generate_strategy_metadata(
strategy_metadata: Sequence[tuple[str, str]] = await _generate_strategy_metadata(
evaluation_result=evaluation_result,
parent_strategy=parent_strategy,
guidelines=guidelines,
Expand All @@ -593,7 +598,7 @@ async def _generate_refined_instructions[
evaluation_result=evaluation_result,
guidelines=guidelines,
)
refined_instruction = instruction.updated(content=refined_content)
refined_instruction: Instruction = instruction.updated(content=refined_content)
refined_strategies.append((strategy_name, refined_instruction))

# Validate and return results
Expand Down Expand Up @@ -691,7 +696,10 @@ async def tree_finalization(
# Update result with best instruction and updated state
return state.updated(
refinement_state,
result=MultimodalContent.of(best_instruction.content),
result=MultimodalContent.of(
best_instruction.content,
meta={"score": best_score},
),
)

return tree_finalization
Expand Down Expand Up @@ -738,8 +746,6 @@ def _log_tree_statistics(
depth_distribution[node.depth] = depth_distribution.get(node.depth, 0) + 1

max_depth: int = max(depth_distribution.keys()) if depth_distribution else 0
theoretical_max: int = 2 ** (max_depth + 1) - 1
efficiency: float = active_nodes / theoretical_max if theoretical_max > 0.0 else 0.0

# Build path to best node
best_path: list[str] = []
Expand All @@ -757,6 +763,5 @@ def _log_tree_statistics(
f"- Active nodes: {active_nodes}\n"
f"- Pruned nodes: {pruned_count} ({pruned_count / total_nodes * 100:.1f}%)\n"
f"- Max depth reached: {max_depth}\n"
f"- Exploration efficiency: {efficiency:.1%}\n"
f"- Best path: {' -> '.join(best_path) if best_path else 'None'}"
)