Skip to content

Commit

Permalink
Additional prompt engineering and output parsing
Browse files Browse the repository at this point in the history
Added and new debugging print commands and enabled pre-existing ones. Further prompt engineering to keep Llama on track and more string parsing to interpret the output. Removed an unused import.
  • Loading branch information
dibrale authored Apr 27, 2023
1 parent 849e5d6 commit 68209d8
Showing 1 changed file with 49 additions and 25 deletions.
74 changes: 49 additions & 25 deletions babyagi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from dotenv import load_dotenv
import ast
import re

# default opt out of chromadb telemetry.
Expand Down Expand Up @@ -108,16 +107,21 @@ def can_import(module_name):
assert os.path.exists(LLAMA_MODEL_PATH), "\033[91m\033[1m" + f"Model can't be found." + "\033[0m\033[0m"

CTX_MAX = 2048
LLAMA_THREADS_NUM = int(os.getenv("LLAMA_THREADS_NUM", 4))
LLAMA_THREADS_NUM = int(os.getenv("LLAMA_THREADS_NUM", 8))
llm = Llama(
model_path=LLAMA_MODEL_PATH,
n_ctx=CTX_MAX, n_threads=LLAMA_THREADS_NUM,
n_ctx=CTX_MAX,
n_threads=LLAMA_THREADS_NUM,
n_batch=512,
use_mlock=True,
)
llm_embed = Llama(
model_path=LLAMA_MODEL_PATH,
n_ctx=CTX_MAX, n_threads=LLAMA_THREADS_NUM,
embedding=True, use_mlock=True,
n_ctx=CTX_MAX,
n_threads=LLAMA_THREADS_NUM,
n_batch=512,
embedding=True,
use_mlock=True,
)

print(
Expand Down Expand Up @@ -380,25 +384,39 @@ def openai_call(
def task_creation_agent(
objective: str, result: Dict, task_description: str, task_list: List[str]
):

prompt = f"""
You are to use the result from an execution agent to create new tasks with the following objective: {objective}.
The last completed task has the result: \n{result["data"]}
This result was based on this task description: {task_description}.\n"""
You are to use the result from an execution agent to create new tasks with the following objective: {objective}.
The last completed task has the result: \n{result["data"]}
This result was based on this task description: {task_description}.\n"""

if task_list:
prompt += f"These are incomplete tasks: {', '.join(task_list)}\n"
prompt += "Based on the result, create a list of new tasks to be completed in order to meet the objective. "
if task_list:
prompt += "These new tasks must not overlap with incomplete tasks. "
prompt += "Return all the new tasks, with one task per line in your response. Do not include any headers before your list."
# print(f'\n************** TASK CREATION AGENT PROMPT *************\n\n{prompt}\n')

prompt += """
Return all the new tasks, with one task per line in your response. The result must be a numbered list in the format:
#. First task
#. Second task
The number of each entry must be followed by a period.
Do not include any headers before your numbered list. Do not follow your numbered list with any other output."""

print(f'\n************** TASK CREATION AGENT PROMPT *************\n{prompt}\n')
response = openai_call(prompt, max_tokens=2000)
# print(f'\n************* TASK CREATION AGENT RESPONSE ************\n\n{response}\n')
print(f'\n************* TASK CREATION AGENT RESPONSE ************\n{response}\n')
new_tasks = response.split('\n')
new_tasks_list = []
for task_string in new_tasks:
task_name = ''.join([re.sub(r'[^\w\s_]+', '', task_string).strip()])
if not task_string == '':
new_tasks_list.append(task_name)
task_parts = task_string.strip().split(".", 1)
if len(task_parts) == 2:
task_id = ''.join(s for s in task_parts[0] if s.isnumeric())
task_name = re.sub(r'[^\w\s_]+', '', task_parts[1]).strip()
if task_name.strip() and task_id.isnumeric():
new_tasks_list.append(task_name)
# print('New task created: ' + task_name)

out = [{"task_name": task_name} for task_name in new_tasks_list]
Expand All @@ -408,20 +426,23 @@ def task_creation_agent(
def prioritization_agent():
task_names = tasks_storage.get_task_names()
next_task_id = tasks_storage.next_task_id()

prompt = f"""
You are tasked with cleaning the format and re-prioritizing the following tasks: {', '.join(task_names)}.
Consider the ultimate objective of your team: {OBJECTIVE}.
Tasks should be sorted from highest to lowest priority.
Higher-priority tasks are those that act as pre-requisites or are more essential for meeting the objective.
Do not remove any tasks. Return the result as a numbered list in the format:
#. First task
#. Second task
The entries are consecutively numbered, starting with 1. The number of each entry must be followed by a period.
Do not include any headers before your numbered list. Do not follow your numbered list with any other output."""
You are tasked with cleaning the format and re-prioritizing the following tasks: {', '.join(task_names)}.
Consider the ultimate objective of your team: {OBJECTIVE}.
Tasks should be sorted from highest to lowest priority.
Higher-priority tasks are those that act as pre-requisites or are more essential for meeting the objective.
Do not remove any tasks. Return the result as a numbered list in the format:
#. First task
#. Second task
The entries are consecutively numbered, starting with 1. The number of each entry must be followed by a period.
Do not include any headers before your numbered list. Do not follow your numbered list with any other output."""

print(f'\n************** TASK PRIORITIZATION AGENT PROMPT *************\n{prompt}\n')
response = openai_call(prompt, max_tokens=2000)
print(f'\n************* TASK PRIORITIZATION AGENT RESPONSE ************\n{response}\n')
new_tasks = response.split("\n") if "\n" in response else [response]
new_tasks_list = []
for task_string in new_tasks:
Expand Down Expand Up @@ -531,8 +552,10 @@ def main():
tasks_storage.get_task_names(),
)

print('Adding new tasks to task_storage')
for new_task in new_tasks:
new_task.update({"task_id": tasks_storage.next_task_id()})
print(str(new_task))
tasks_storage.append(new_task)

if not JOIN_EXISTING_OBJECTIVE: prioritization_agent()
Expand All @@ -543,5 +566,6 @@ def main():
print('Done.')
loop = False


if __name__ == "__main__":
main()

0 comments on commit 68209d8

Please sign in to comment.