Add Dataclasses and RepoEnv Info refac #50

Merged: 19 commits, Feb 11, 2025

99 changes: 40 additions & 59 deletions example_agent/example_agent.py
@@ -50,11 +50,11 @@ def build_system_prompt(self, info):
system_prompt["Overall task"] = (
"Your goal is to debug a Python program to make sure it can pass a set of test functions. You have access to the pdb debugger tools, you can use them to investigate the code, set breakpoints, and print necessary values to identify the bugs. Once you have gained enough information, propose a rewriting patch to fix the bugs. Avoid rewriting the entire code, focus on the bugs only."
)
system_prompt["Instructions"] = info["instructions"]
system_prompt["Repo directory tree"] = info["dir_tree"]
system_prompt["Current code in view"] = info["current_code_with_line_number"]
system_prompt["Current breakpoints"] = info["current_breakpoints"]
system_prompt["Last execution output"] = info["last_run_obs"]
system_prompt["Instructions"] = info.instructions
system_prompt["Repo directory tree"] = info.dir_tree
system_prompt["Current code in view"] = info.current_code_with_line_number
system_prompt["Current breakpoints"] = info.current_breakpoints
system_prompt["Last execution output"] = info.last_run_obs
system_prompt = unescape(json.dumps(system_prompt, indent=4))
messages = [
{
@@ -80,49 +80,39 @@ def build_prompt(self, info):

def run(self, task_name=None, debug=False):
self.history.reset()
_, info = self.env.reset(options={"task_name": task_name})
self.history.step(info)
info = self.env.reset(options={"task_name": task_name})
# initial state does not have prompt and response
self.history.step(info, None)

if info["done"] is True:
if info.done is True:
# msg = "Environment started with entrypoint passing without errors."
return True

done = False
highscore = info["score"]
highscore = info.score

for step in self.logger.tqdm(range(self.config["max_steps"])):
highscore = max(highscore, info["score"])
highscore = max(highscore, info.score)
self.logger.info(
f"Score: {info['score']}/{info['max_score']} ({info['score']/info['max_score']:.1%}) [Best: {highscore}]".format(
info["score"]
)
f"Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%}) [Best: {highscore}]"
)

prompt = self.build_prompt(info)
answer, token_usage = self.llm(
llm_response = self.llm(
prompt, info, temperature=self.config["llm_temperature"][0]
)

if debug:
breakpoint()

_, _, done, info = self.env.step(answer)
info["token_usage"] = [
token_usage
] # in some other agents this is a list because of multi-step llm calls
self.history.step(info)
self.history.save_prompt_response_pairs(
prompt_response_pairs=[(prompt, answer)]
)
info = self.env.step(llm_response.response)
self.history.step(info, llm_response)

if done or info["rewrite_counter"] >= self.config["max_rewrite_steps"]:
if info.done or info.rewrite_counter >= self.config["max_rewrite_steps"]:
self.logger.info(
f"Score: {info['score']}/{info['max_score']} ({info['score']/info['max_score']:.1%})".format(
info["score"]
)
f"Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%})"
)
break
return done
return info.done

def apply_patch(self, patch_path: str) -> bool:
patch_command = ["patch", "-p1"]
@@ -189,11 +179,11 @@ def build_system_prompt(self, info):
system_prompt["Overall task"] = (
"Your goal is to debug a Python program to make sure it can pass a set of test functions. You need to propose a rewriting patch to fix the bugs. Avoid rewriting the entire code, focus on the bugs only."
)
system_prompt["Instructions"] = info["instructions"]
system_prompt["Repo directory tree"] = info["dir_tree"]
system_prompt["Current code in view"] = info["current_code_with_line_number"]
system_prompt["Current breakpoints"] = info["current_breakpoints"]
system_prompt["Last execution output"] = info["last_run_obs"]
system_prompt["Instructions"] = info.instructions
system_prompt["Repo directory tree"] = info.dir_tree
system_prompt["Current code in view"] = info.current_code_with_line_number
system_prompt["Current breakpoints"] = info.current_breakpoints
system_prompt["Last execution output"] = info.last_run_obs
system_prompt = unescape(json.dumps(system_prompt, indent=4))
messages = [
{
@@ -223,58 +213,49 @@ def run(self, task_name=None, debug=False):
pdb_tool = self.env.tools.pop("pdb")

self.history.reset()
_, info = self.env.reset(options={"task_name": task_name})
self.history.step(info)
info = self.env.reset(options={"task_name": task_name})
# initial state does not have prompt and response
self.history.step(info, None)

if info["done"] is True:
if info.done is True:
# msg = "Environment started with entrypoint passing without errors."
return True

done = False
highscore = info["score"]
highscore = info.score

for step in self.logger.tqdm(range(self.config["max_steps"])):
highscore = max(highscore, info["score"])
highscore = max(highscore, info.score)
self.logger.info(
f"Score: {info['score']}/{info['max_score']} ({info['score']/info['max_score']:.1%}) [Best: {highscore}]".format(
info["score"]
)
f"Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%}) [Best: {highscore}]"
)

prompt = self.build_prompt(info)
answer, token_usage = self.llm(

llm_response = self.llm(
prompt, info, temperature=self.config["llm_temperature"][0]
)

if debug:
breakpoint()

_, _, done, info = self.env.step(answer)
info["token_usage"] = [
token_usage
] # in some other agents this is a list because of multi-step llm calls
info = self.env.step(llm_response.response)

# re-introduce pdb tool at the right time
if (
info["rewrite_counter"] >= self.config["n_rewrites_before_pdb"]
info.rewrite_counter >= self.config["n_rewrites_before_pdb"]
and pdb_tool.name not in self.env.tools
):
self.env.add_tool(pdb_tool)
self.env.tools["pdb"].start_pdb()
info["instructions"] = self.env.instructions
info["obs"] += "\nThe pdb tool has been added."
info.instructions = self.env.instructions
info.obs += "\nThe pdb tool has been added."

self.history.step(info)
self.history.save_prompt_response_pairs(
prompt_response_pairs=[(prompt, answer)]
)
self.history.step(info, llm_response)

if done or info["rewrite_counter"] >= self.config["max_rewrite_steps"]:
if info.done or info.rewrite_counter >= self.config["max_rewrite_steps"]:
self.logger.info(
f"Score: {info['score']}/{info['max_score']} ({info['score']/info['max_score']:.1%})".format(
info["score"]
)
f"Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%})"
)
break

return done
return info.done
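
Note: the agent code above now reads environment observations as attributes of an `EnvInfo` object instead of dictionary keys. `EnvInfo` itself is defined in `froggy/envs/env.py` and is not shown in this diff; a minimal sketch consistent with the fields used here and in the new `build_env_info` fixture below (field types and defaults are assumptions, not taken from the PR) might look like this:

from dataclasses import dataclass

# Sketch only: the real EnvInfo in froggy/envs/env.py may differ in types,
# defaults, and field order. Field names match this diff and the
# build_env_info fixture in example_agent/tests/conftest.py.
@dataclass
class EnvInfo:
    obs: str
    max_score: int
    score: int
    last_run_obs: str
    dbg_obs: str
    dir_tree: str
    current_code_with_line_number: str
    current_breakpoints: str
    action: str
    instructions: dict
    done: bool
    rewrite_counter: int
    tools: dict
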
84 changes: 61 additions & 23 deletions example_agent/llm_api.py
@@ -2,6 +2,7 @@
import logging
import os
import sys
from dataclasses import dataclass

import tiktoken
from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
@@ -14,7 +15,6 @@
from termcolor import colored
from transformers import AutoTokenizer

from example_agent.utils import trim_prompt_messages
from froggy.logger import FroggyLogger

prompt_toolkit_available = False
@@ -71,6 +71,34 @@ def merge_messages(messages):
return messages_out


@dataclass
class TokenUsage:
prompt: int
response: int


@dataclass
class LLMResponse:
prompt: list[dict] | str # either a string or a list of messages.
response: str
token_usage: TokenUsage | None = None

def __init__(
self,
prompt: list[dict] | str,
response: str,
prompt_token_count: int = None,
response_token_count: int = None,
token_usage: TokenUsage = None,
):
self.prompt = prompt
self.response = response
if prompt_token_count is not None and response_token_count is not None:
self.token_usage = TokenUsage(prompt_token_count, response_token_count)
else:
self.token_usage = token_usage


class TokenCounter:
def __init__(self, model: str = "gpt-4o"):
self.model = model
@@ -88,7 +116,7 @@ def __init__(self, model: str = "gpt-4o"):
)
raise ValueError(msg)

def __call__(self, *, messages=None, text=None):
def __call__(self, *, messages=None, text=None) -> int:
nb_tokens = 0
if messages is not None:
nb_tokens += sum(len(self.tokenize(msg["content"])) for msg in messages)
@@ -206,7 +234,9 @@ def query_model(self, messages, **kwargs):
)
return reponse.choices[0].message.content

def __call__(self, messages, *args, **kwargs):
def __call__(self, messages, *args, **kwargs) -> LLMResponse:
from example_agent.utils import trim_prompt_messages

if not self.config.get("system_prompt_support", True):
# Replace system by user
for i, m in enumerate(messages):
@@ -228,12 +258,13 @@ def __call__(self, messages, *args, **kwargs):

self.logger.debug(colored(response, "green"))

token_usage = {
"prompt": self.token_counter(messages=messages),
"response": self.token_counter(text=response),
}

return response, token_usage
llm_response = LLMResponse(
prompt=messages,
response=response,
prompt_token_count=self.token_counter(messages=messages),
response_token_count=self.token_counter(text=response),
)
return llm_response


class AsyncLLM(LLM):
@@ -262,7 +293,7 @@ async def query_model(self, messages, **kwargs):
)
return response.choices[0].message.content

async def __call__(self, messages, *args, **kwargs):
async def __call__(self, messages, *args, **kwargs) -> LLMResponse:
if not self.config.get("system_prompt_support", True):
# Replace system by user
for i, m in enumerate(messages):
@@ -272,12 +303,13 @@ async def __call__(self, messages, *args, **kwargs):
response = await self.query_model(messages, **kwargs)
response = response.strip()

token_usage = {
"prompt": self.token_counter(messages=messages),
"response": self.token_counter(text=response),
}

return response, token_usage
llm_response = LLMResponse(
prompt=messages,
response=response,
prompt_token_count=self.token_counter(messages=messages),
response_token_count=self.token_counter(text=response),
)
return llm_response


class Human:
@@ -287,10 +319,10 @@ def __init__(self, logger: FroggyLogger | None = None):
if prompt_toolkit_available:
self._history = InMemoryHistory()

def __call__(self, messages, info, *args, **kwargs):
def __call__(self, messages, info, *args, **kwargs) -> LLMResponse:
# Color each role differently.
print_messages(messages, self.logger)
available_commands = [t["template"] for t in info["tools"].values()]
available_commands = [t["template"] for t in info.tools.values()]
if prompt_toolkit_available:
actions_completer = WordCompleter(
available_commands, ignore_case=True, sentence=True
@@ -305,12 +337,18 @@
self.logger.info("\n".join(["Available commands:"] + available_commands))
action = input("> ")

token_usage = {
"prompt": len("\n".join([msg["content"] for msg in messages])),
"response": len(action),
}
prompt_messages = "\n".join([msg["content"] for msg in messages])
token_usage = TokenUsage(
prompt=len(prompt_messages),
response=len(action),
)

return action, token_usage
return LLMResponse(
prompt=prompt_messages,
response=action,
prompt_token_count=len(prompt_messages),
response_token_count=len(action),
)


def instantiate_llm(
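
Note: with this refactor, LLM call sites receive a single `LLMResponse` object instead of an `(answer, token_usage)` tuple. Below is a short illustrative sketch of the new calling convention; `llm`, `messages`, and `info` are placeholder names for this example, not identifiers from the PR:

# Illustrative only: consuming the new return type.
llm_response = llm(messages, info, temperature=0.5)
answer = llm_response.response              # the model's text reply
if llm_response.token_usage is not None:
    prompt_tokens = llm_response.token_usage.prompt
    response_tokens = llm_response.token_usage.response

# Passing both token counts to the constructor builds a TokenUsage internally.
resp = LLMResponse(prompt="p", response="r", prompt_token_count=3, response_token_count=1)
assert resp.token_usage == TokenUsage(prompt=3, response=1)
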
38 changes: 38 additions & 0 deletions example_agent/tests/conftest.py
@@ -3,6 +3,44 @@

import pytest

from froggy.envs.env import EnvInfo


@pytest.fixture
def build_env_info():
def _env_info(
obs="obs",
max_score=10,
score=5,
last_run_obs="last_run_obs",
dbg_obs="dbg_obs",
dir_tree="dir_tree",
current_code_with_line_number="current_code_with_line_number",
current_breakpoints="current_breakpoints",
action="action",
instructions=None,
done=False,
rewrite_counter=0,
tools=None,
):
return EnvInfo(
obs=obs,
max_score=max_score,
score=score,
last_run_obs=last_run_obs,
dbg_obs=dbg_obs,
dir_tree=dir_tree,
current_code_with_line_number=current_code_with_line_number,
current_breakpoints=current_breakpoints,
action=action,
instructions=instructions if instructions is not None else {},
done=done,
rewrite_counter=rewrite_counter,
tools=tools if tools is not None else {},
)

return _env_info
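
A hypothetical test (not part of this PR) showing how the fixture above might be used:

# Hypothetical usage sketch of the build_env_info factory fixture.
def test_env_info_attribute_access(build_env_info):
    info = build_env_info(score=7, max_score=10, done=False)
    assert info.score == 7
    assert info.max_score == 10
    assert info.done is False
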


@pytest.fixture
def open_data():