Skip to content

Commit

Permalink
Merge pull request #498 from julep-ai/SCRUM-15-agents-api-Implement-p…
Browse files Browse the repository at this point in the history
…rompt-step-if-tool_calls

feat: Add agent tools to completion data before sending to litellm in prompt
  • Loading branch information
HamadaSalhab authored Sep 16, 2024
2 parents 4ce4810 + 06666ee commit 5146369
Show file tree
Hide file tree
Showing 14 changed files with 101 additions and 25 deletions.
Binary file added .DS_Store
Binary file not shown.
Binary file added agents-api/.DS_Store
Binary file not shown.
25 changes: 24 additions & 1 deletion agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
)
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.utils.template import render_template
from ...models.tools.list_tools import list_tools


@activity.defn
Expand Down Expand Up @@ -34,6 +35,28 @@ async def prompt_step(context: StepContext) -> StepOutcome:
else "gpt-4o"
)

agent_tools = list_tools(
developer_id=context.execution_input.developer_id,
agent_id=context.execution_input.agent.id,
limit=128, # Max number of supported functions in OpenAI. See https://platform.openai.com/docs/api-reference/chat/create
offset=0,
sort_by="created_at",
direction="desc",
)

# Format agent_tools for litellm
formatted_agent_tools = [
{
"type": tool.type,
"function": {
"name": tool.function.name,
"description": tool.function.description,
"parameters": tool.function.parameters,
},
}
for tool in agent_tools
]

if context.current_step.settings:
passed_settings: dict = context.current_step.settings.model_dump(
exclude_unset=True
Expand All @@ -43,11 +66,11 @@ async def prompt_step(context: StepContext) -> StepOutcome:

completion_data: dict = {
"model": agent_model,
"tools": formatted_agent_tools or None,
("messages" if isinstance(prompt, list) else "prompt"): prompt,
**agent_default_settings,
**passed_settings,
}

response = await litellm.acompletion(
**completion_data,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
from temporalio import activity

from ...autogen.openapi_model import CreateTransitionRequest
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from .transition_step import original_transition_step


@activity.defn
async def raise_complete_async() -> None:
async def raise_complete_async(context: StepContext, output: StepOutcome) -> None:
# TODO: Create a transtition to "wait" and save the captured_token to the transition

captured_token = activity.info().task_token
captured_token = captured_token.decode('latin-1')
transition_info = CreateTransitionRequest(
current=context.cursor,
type="wait",
next=None,
output=output,
task_token=captured_token,
)

await original_transition_step(context, transition_info)

# await transition(context, output=output, type="wait", next=None, task_token=captured_token)

print("transition to wait called")
activity.raise_complete_async()
12 changes: 3 additions & 9 deletions agents-api/agents_api/activities/task_steps/transition_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,21 @@ async def transition_step(
context: StepContext,
transition_info: CreateTransitionRequest,
) -> Transition:
need_to_wait = transition_info.type == "wait"

# Get task token if it's a waiting step
if need_to_wait:
task_token = activity.info().task_token
transition_info.task_token = task_token

# Create transition
transition = create_execution_transition(
developer_id=context.execution_input.developer_id,
execution_id=context.execution_input.execution.id,
task_id=context.execution_input.task.id,
data=transition_info,
task_token=transition_info.task_token,
update_execution_status=True,
)

return transition


original_transition_step = transition_step
mock_transition_step = transition_step

transition_step = activity.defn(name="transition_step")(
transition_step if not testing else mock_transition_step
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ async def wait_for_input_step(context: StepContext) -> StepOutcome:
try:
assert isinstance(context.current_step, WaitForInputStep)

exprs = context.current_step.wait_for_input
exprs = context.current_step.wait_for_input.info
output = await base_evaluate(exprs, context.model_dump())

result = StepOutcome(output=output)
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class CreateTransitionRequest(Transition):
created_at: AwareDatetime | None = None
updated_at: AwareDatetime | None = None
metadata: dict[str, Any] | None = None
task_token: str | None = None


class CreateEntryRequest(BaseEntry):
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
"error": [],
"cancelled": [],
# Intermediate states
"wait": ["resume", "cancelled"],
"wait": ["resume", "cancelled", "finish", "finish_branch"],
"resume": [
"wait",
"error",
Expand Down Expand Up @@ -100,7 +100,7 @@
"queued": [],
"awaiting_input": ["starting", "running"],
"cancelled": ["queued", "starting", "awaiting_input", "running"],
"succeeded": ["starting", "running"],
"succeeded": ["starting", "awaiting_input", "running"],
"failed": ["starting", "running"],
} # type: ignore

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def create_execution_transition(
task_id: UUID | None = None,
) -> tuple[list[str | None], dict]:
transition_id = transition_id or uuid4()

data.metadata = data.metadata or {}
data.execution_id = execution_id

Expand All @@ -111,7 +110,7 @@ def create_execution_transition(
columns, transition_values = cozo_process_mutate_data(
{
**transition_data,
"task_token": task_token,
"task_token": str(task_token), # Converting to str for JSON serialisation
"transition_id": str(transition_id),
"execution_id": str(execution_id),
}
Expand Down
5 changes: 3 additions & 2 deletions agents-api/agents_api/routers/tasks/update_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ async def update_execution(
await wf_handle.cancel()

case ResumeExecutionRequest():

token_data = get_paused_execution_token(
developer_id=x_developer_id, execution_id=execution_id
)
act_handle = temporal_client.get_async_activity_handle(
token_data["task_token"]
task_token=str.encode(token_data["task_token"], encoding="latin-1")
)
await act_handle.complete(data.input)

print("Resumed execution successfully")
case _:
raise HTTPException(status_code=400, detail="Invalid request data")
34 changes: 30 additions & 4 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,12 @@ async def run(

case WaitForInputStep(), StepOutcome(output=output):
workflow.logger.info("Wait for input step: Waiting for external input")
await transition(context, output=output, type="wait", next=None)

await transition(context, type="wait", output=output)

result = await workflow.execute_activity(
task_steps.raise_complete_async,
args=[context, output],
schedule_to_close_timeout=timedelta(days=31),
)

Expand All @@ -391,8 +393,33 @@ async def run(
output=response
): # FIXME: if not response.choices[0].tool_calls:
# SCRUM-15
workflow.logger.debug("Prompt step: Received response")
state = PartialTransition(output=response)
workflow.logger.debug(f"Prompt step: Received response: {response}")
if response["choices"][0]["finish_reason"] != "tool_calls":
workflow.logger.debug("Prompt step: Received response")
state = PartialTransition(output=response)
else:
workflow.logger.debug("Prompt step: Received tool call")
message = response["choices"][0]["message"]
tool_calls_input = message["tool_calls"]

# Enter a wait-for-input step to ask the developer to run the tool calls
tool_calls_results = await workflow.execute_activity(
task_steps.raise_complete_async,
args=[context, tool_calls_input],
schedule_to_close_timeout=timedelta(days=31),
)
# Feed the tool call results back to the model
# context.inputs.append(tool_calls_results)
context.current_step.prompt.append(message)
context.current_step.prompt.append(tool_calls_results)
new_response = await workflow.execute_activity(
task_steps.prompt_step,
context,
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600),
)
state = PartialTransition(
output=new_response.output, type="resume")

# case PromptStep(), StepOutcome(
# output=response
Expand Down Expand Up @@ -453,7 +480,6 @@ async def run(

# 4. Transition to the next step
workflow.logger.info(f"Transitioning after step {context.cursor.step}")

# The returned value is the transition finally created
final_state = await transition(context, state)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ async def transition(
state.type = "finish_branch"
case _, _:
state.type = "step"

transition_request = CreateTransitionRequest(
current=context.cursor,
**{
Expand Down
2 changes: 1 addition & 1 deletion agents-api/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ x--shared-environment: &shared-environment
AGENTS_API_KEY: ${AGENTS_API_KEY}
AGENTS_API_KEY_HEADER_NAME: ${AGENTS_API_KEY_HEADER_NAME:-Authorization}
AGENTS_API_HOSTNAME: ${AGENTS_API_HOSTNAME:-localhost}
AGENTS_API_PROTOCOL: ${AGENTS_API_PROTOCOL:-http}
AGENTS_API_PUBLIC_PORT: ${AGENTS_API_PUBLIC_PORT:-80}
AGENTS_API_PROTOCOL: ${AGENTS_API_PROTOCOL:-http}
AGENTS_API_URL: ${AGENTS_API_URL:-http://agents-api:8080}
COZO_AUTH_TOKEN: ${COZO_AUTH_TOKEN}
COZO_HOST: ${COZO_HOST:-http://memory-store:9070}
Expand Down
11 changes: 10 additions & 1 deletion agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ async def _(
mock_run_task_execution_workflow.assert_called_once()

# Let it run for a bit
await asyncio.sleep(1)
await asyncio.sleep(3)

# Get the history
history = await handle.fetch_history()
Expand All @@ -497,6 +497,15 @@ async def _(
activity for activity in activities_scheduled if activity
]

try:
future = handle.result()
breakpoint()
await future
except BaseException as exc:
print("exc", exc)
breakpoint()
raise

assert "wait_for_input_step" in activities_scheduled


Expand Down

0 comments on commit 5146369

Please sign in to comment.