Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/rai_core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "rai_core"
version = "2.5.4"
version = "2.6.4"
description = "Core functionality for RAI framework"
authors = ["Maciej Majek <maciej.majek@robotec.ai>", "Bartłomiej Boczek <bartlomiej.boczek@robotec.ai>", "Kajetan Rachwał <kajetan.rachwal@robotec.ai>"]
readme = "README.md"
Expand Down
68 changes: 53 additions & 15 deletions src/rai_core/rai/agents/langchain/core/megamind.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,27 +209,64 @@ def get_initial_megamind_state(task: str):
)


@dataclass
class PlanPrompts:
"""Configurable prompts for the planning step."""

objective_template: str = "You are given objective to complete: {original_task}"
steps_done_header: str = "Steps that were already done successfully:\n"
next_step_prompt: str = "\nBased on that outcome and past steps come up with the next step and delegate it to selected agent."
first_step_prompt: str = (
"\nCome up with the first step and delegate it to selected agent."
)
completion_prompt: str = (
"\n\nWhen you decide that the objective is completed return response to user."
)

@classmethod
def default(cls):
"""Return default prompts."""
return cls()

@classmethod
def custom(cls, **kwargs):
"""Create custom prompts with overrides."""
default = cls.default()
for key, value in kwargs.items():
if hasattr(default, key):
setattr(default, key, value)
return default


def plan_step(
megamind_agent: BaseChatModel,
state: MegamindState,
prompts: Optional[PlanPrompts] = None,
context_providers: Optional[List[ContextProvider]] = None,
) -> MegamindState:
"""Initial planning step."""
if prompts is None:
prompts = PlanPrompts.default()

if "original_task" not in state:
state["original_task"] = state["messages"][0].content[0]["text"]
if "steps_done" not in state:
state["steps_done"] = []
if "step" not in state:
state["step"] = None

megamind_prompt = f"You are given objective to complete: {state['original_task']}"
megamind_prompt = prompts.objective_template.format(
original_task=state["original_task"]
)
if context_providers:
for provider in context_providers:
megamind_prompt += provider.get_context()
megamind_prompt += "\n"

# Add completed steps if any
if state["steps_done"]:
megamind_prompt += "\n\n"
megamind_prompt += "Steps that were already done successfully:\n"
megamind_prompt += prompts.steps_done_header
steps_done = "\n".join(
[f"{i + 1}. {step}" for i, step in enumerate(state["steps_done"])]
)
Expand All @@ -239,22 +276,17 @@ def plan_step(
if state["step"]:
if not state["step_success"]:
raise ValueError("Step success should be specified at this point")

megamind_prompt += "\nBased on that outcome and past steps come up with the next step and delegate it to selected agent."
megamind_prompt += prompts.next_step_prompt

else:
megamind_prompt += "\n"
megamind_prompt += (
"Come up with the fist step and delegate it to selected agent."
)
megamind_prompt += prompts.first_step_prompt

megamind_prompt += prompts.completion_prompt

megamind_prompt += "\n\n"
megamind_prompt += (
"When you decide that the objective is completed return response to user."
)
messages = [
HumanMultimodalMessage(content=megamind_prompt),
]

# NOTE (jmatejcz) the response of megamind isnt appended to messages
# as Command from handoff instantly transitions to next node
megamind_agent.invoke({"messages": messages})
Expand All @@ -265,7 +297,8 @@ def create_megamind(
megamind_llm: BaseChatModel,
executors: List[Executor],
megamind_system_prompt: Optional[str] = None,
task_planning_prompt: Optional[str] = None,
anylyzer_prompt: Optional[str] = None,
plan_prompts: Optional[PlanPrompts] = None,
context_providers: List[ContextProvider] = [],
) -> CompiledStateGraph:
"""Create a megamind langchain agent
Expand All @@ -292,7 +325,7 @@ def create_megamind(
llm=executor.llm,
tools=executor.tools,
system_prompt=executor.system_prompt,
planning_prompt=task_planning_prompt,
planning_prompt=anylyzer_prompt,
)

handoff_tools.append(
Expand Down Expand Up @@ -325,7 +358,12 @@ def create_megamind(

graph = StateGraph(MegamindState).add_node(
"megamind",
partial(plan_step, megamind_agent, context_providers=context_providers),
partial(
plan_step,
megamind_agent,
context_providers=context_providers,
prompts=plan_prompts,
),
)
for agent_name, agent in executor_agents.items():
graph.add_node(agent_name, agent)
Expand Down