Commit b9d089a

Merge pull request #58 from commit-0/agent
Update aider with better display
2 parents 1711fdf + f12b903 commit b9d089a

10 files changed

+1120
-157
lines changed

agent/README.md

Lines changed: 67 additions & 33 deletions
@@ -1,39 +1,73 @@
1-
# How to run baseline
1+
# Agent for Commit0
2+
This tool provides a command-line interface for configuring and running AI agents to assist with code development and testing.
23

3-
Step 1: Go to `config/aider.yaml` and change the config
4+
## Quick Start
5+
Configure an agent:
6+
```bash
7+
agent config [OPTIONS] AGENT_NAME
8+
```
49

5-
Step 2: Run the following command
10+
Run an agent on a specific branch:
11+
```bash
12+
agent run [OPTIONS] BRANCH
13+
```
614

15+
For more detailed information on available commands and options:
716
```bash
8-
python baselines/run_aider.py
17+
agent -h
18+
agent config -h
19+
agent run -h
920
```
21+
## Configure an Agent
22+
Use `agent config [OPTIONS] AGENT_NAME` to set up the configuration for an agent.
23+
Available options include:
24+
25+
`--agent_name: str`: Agent to use. Only [aider](https://aider.chat/) is supported for now. [Default: `aider`]
26+
`--model-name: str`: LLM model to use, check [here](https://aider.chat/docs/llms.html) for all supported models. [Default: `claude-3-5-sonnet-20240620`]
27+
`--use-user-prompt: bool`: Use a custom prompt instead of the default prompt. [Default: `False`]
28+
`--user-prompt: str`: The prompt sent to agent. [Default: See code for details.]
29+
`--run-tests: bool`: Run tests after code modifications for feedback. You need to set up `docker` or `modal` before running tests; refer to the commit0 docs. [Default: `False`]
30+
`--max-iteration: int`: Maximum number of agent iterations. [Default: `3`]
31+
`--use-repo-info: bool`: Include the repository information. [Default: `False`]
32+
`--max-repo-info-length: int`: Maximum length of the repository information to use. [Default: `10000`]
33+
`--use-unit-tests-info: bool`: Include the unit tests information. [Default: `False`]
34+
`--max-unit-tests-info-length: int`: Maximum length of the unit tests information to use. [Default: `10000`]
35+
`--use-spec-info: bool`: Include the spec information. [Default: `False`]
36+
`--max-spec-info-length: int`: Maximum length of the spec information to use. [Default: `10000`]
37+
`--use-lint-info: bool`: Include the lint information. [Default: `False`]
38+
`--max-lint-info-length: int`: Maximum length of the lint information to use. [Default: `10000`]
39+
`--pre-commit-config-path: str`: Path to the pre-commit config file. This is needed for running `lint`. [Default: `.pre-commit-config.yaml`]
40+
`--agent-config-file: str`: Path to write the agent config. [Default: `.agent.yaml`]
41+
42+
## Running Agent
43+
Use `agent run [OPTIONS] BRANCH` to execute an agent on a specific branch.
44+
Available options include:
45+
46+
`--branch: str`: Branch to run the agent on; you can specify the name of the branch.
47+
`--backend: str`: Test backend to run the agent on; ignore this option if you are not adding the `run_tests` option to the agent. [Default: `modal`]
48+
`--log-dir: str`: Log directory to store the logs. [Default: `logs/aider`]
49+
`--max-parallel-repos: int`: Maximum number of repositories for the agent to run in parallel. Runs sequentially if set to 1. [Default: `1`]
50+
`--display-repo-progress-num: int`: Number of repos whose progress is displayed while running. [Default: `5`]
51+
52+
53+
### Example: Running aider
54+
Step 1: Configure aider: `agent config aider`
55+
Step 2: Run aider on a branch: `agent run aider_branch`
56+
57+
### Other Agent:
58+
Refer to `class Agents` in `agent/agents.py`. You can design your own agent by inheriting from the `Agents` class and implementing the `run` method.
59+
60+
## Notes
61+
62+
### Automatically retry
63+
Aider automatically retries certain API errors. For details, see [here](https://github.com/paul-gauthier/aider/blob/75e1d519da9b328b0eca8a73ee27278f1289eadb/aider/sendchat.py#L17).
64+
65+
### Parallelize agent running
66+
When increasing `--max-parallel-repos`, be mindful of aider's [60-second retry timeout](https://github.com/paul-gauthier/aider/blob/75e1d519da9b328b0eca8a73ee27278f1289eadb/aider/sendchat.py#L39). Set this value according to your API tier so that RateLimitErrors do not stop processes.
67+
68+
### Large files in repo
69+
Currently, the agent skips files with more than 1500 lines. See `agent/agent_utils.py#L199` for details.
70+
71+
### Cost
72+
Running a full `all` commit0 split costs approximately $100.
1073

11-
## Config
12-
13-
`commit0_config`:
14-
15-
- `base_dir`: Repos dir. Default `repos`.
16-
- `dataset_name`: commit0 HF dataset name. Default: `wentingzhao/commit0_docstring`.
17-
- `dataset_split`: commit0 dataset split. Default: `test`.
18-
- `repo_split`: commit0 repo split. Default: `simpy`.
19-
- `num_workers`: number of workers to run in parallel. Default: `10`.
20-
21-
`aider_config`:
22-
23-
- `llm_name`: LLM model name. Default: `claude-3-5-sonnet-20240620`.
24-
- `use_user_prompt`: Whether to use user prompt. Default: `false`.
25-
- `user_prompt`: User prompt. Default: `""`.
26-
- `use_repo_info`: Whether to use repo info. Default: `false`.
27-
- Repo info
28-
- skeleton of the repo(filenames under each dir)
29-
- function stubs
30-
31-
- `use_unit_tests_info`: Whether to use unit tests: unit_tests that target will be tested with. Default: `false`.
32-
- `use_reference_info`: Whether to use reference: reference doc/pdf/website. Default: `false`.
33-
- `use_lint_info`: Whether to use lint: lint info. Default: `false`.
34-
- `pre_commit_config_path`: Path to pre-commit config. Default: `.pre-commit-config.yaml`.
35-
- `run_tests`: Whether to run tests. Default: `true`.
36-
- `max_repo_info_length`: Max length of repo info. Default: `10000`.
37-
- `max_unit_tests_info_length`: Max length of unit tests info. Default: `10000`.
38-
- `max_reference_info_length`: Max length of reference info. Default: `10000`.
39-
- `max_lint_info_length`: Max length of lint info. Default: `10000`.
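
The new README says a custom agent can be built by inheriting from `Agents` and implementing `run`. The following is a minimal sketch of that pattern, not the project's implementation. `Agents` and `AgentReturn` are redefined here as simplified stand-ins so the block runs without aider installed (the real `AgentReturn` also parses the cost from the log file), and `EchoAgent` is a hypothetical agent invented purely for illustration.

```python
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path


class AgentReturn:
    """Simplified stand-in: only remembers where the agent logged its work."""

    def __init__(self, log_file: Path):
        self.log_file = log_file


class Agents(ABC):
    """Simplified stand-in for the base class in agent/agents.py."""

    def __init__(self, max_iteration: int):
        self.max_iteration = max_iteration

    @abstractmethod
    def run(self) -> AgentReturn:
        """Start agent"""
        raise NotImplementedError


class EchoAgent(Agents):
    """Hypothetical agent that only writes its prompt to a log file."""

    def run(self, message: str, log_dir: Path) -> AgentReturn:
        log_dir.mkdir(parents=True, exist_ok=True)
        log_file = log_dir / "echo.log"
        with open(log_file, "a", encoding="utf-8") as f:
            # One log line per iteration, bounded by max_iteration.
            for i in range(self.max_iteration):
                f.write(f"iteration {i}: {message}\n")
        return AgentReturn(log_file)


# Usage: run the hypothetical agent in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    result = EchoAgent(max_iteration=2).run("implement foo", Path(tmp) / "logs")
    logged_lines = result.log_file.read_text(encoding="utf-8").splitlines()

print(logged_lines)
```

As in the diff, the subclass is free to give `run` its own parameters; the base class only fixes that the method exists and returns an `AgentReturn`.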

agent/commit0_utils.py renamed to agent/agent_utils.py

Lines changed: 103 additions & 17 deletions
@@ -5,6 +5,7 @@
55
from pathlib import Path
66
from typing import List
77
import fitz
8+
import yaml
89

910
from agent.class_types import AgentConfig
1011

@@ -118,24 +119,95 @@ def get_file_info(file_path: Path, prefix: str = "") -> str:
118119
return "\n".join(filter(None, tree_string))
119120

120121

121-
def get_target_edit_files(target_dir: str) -> list[str]:
122+
def collect_test_files(directory: str) -> list[str]:
123+
"""Collect all the test files in the directory."""
124+
test_files = []
125+
subdirs = []
126+
127+
# Walk through the directory
128+
for root, dirs, files in os.walk(directory):
129+
if root.endswith("/"):
130+
root = root[:-1]
131+
# Check if 'test' is part of the folder name
132+
if (
133+
"test" in os.path.basename(root).lower()
134+
or os.path.basename(root) in subdirs
135+
):
136+
for file in files:
137+
# Process only Python files
138+
if file.endswith(".py"):
139+
file_path = os.path.join(root, file)
140+
test_files.append(file_path)
141+
for d in dirs:
142+
subdirs.append(d)
143+
144+
return test_files
145+
146+
147+
def collect_python_files(directory: str) -> list[str]:
148+
"""Collect all the .py files in the directory, recursively."""
149+
python_files = []
150+
151+
# Walk through the directory recursively
152+
for root, _, files in os.walk(directory):
153+
for file in files:
154+
# Check if the file ends with '.py'
155+
if file.endswith(".py"):
156+
file_path = os.path.join(root, file)
157+
python_files.append(file_path)
158+
159+
return python_files
160+
161+
162+
def _find_files_to_edit(base_dir: str, src_dir: str, test_dir: str) -> list[str]:
163+
"""Identify files to remove content by heuristics.
164+
We assume source code is under [lib]/[lib] or [lib]/src.
165+
We exclude test code. This function would not work
166+
if test code doesn't have its own directory.
167+
168+
Args:
169+
----
170+
base_dir (str): The path to local library.
171+
src_dir (str): The directory containing source code.
172+
test_dir (str): The directory containing test code.
173+
174+
Returns:
175+
-------
176+
list[str]: A list of files to be edited.
177+
178+
"""
179+
files = collect_python_files(os.path.join(base_dir, src_dir))
180+
test_files = collect_test_files(os.path.join(base_dir, test_dir))
181+
files = list(set(files) - set(test_files))
182+
183+
# don't edit __init__ files
184+
files = [f for f in files if "__init__" not in f]
185+
# don't edit __main__ files
186+
files = [f for f in files if "__main__" not in f]
187+
# don't edit conftest.py files
188+
files = [f for f in files if "conftest.py" not in f]
189+
return files
190+
191+
192+
def get_target_edit_files(target_dir: str, src_dir: str, test_dir: str) -> list[str]:
122193
"""Find the files with functions with the pass statement."""
123-
files = []
124-
for root, _, filenames in os.walk(target_dir):
125-
for filename in filenames:
126-
if filename.endswith(".py"):
127-
file_path = os.path.join(root, filename)
128-
with open(file_path, "r") as file:
129-
if " pass" in file.read():
130-
files.append(file_path)
194+
files = _find_files_to_edit(target_dir, src_dir, test_dir)
195+
filtered_files = []
196+
for file_path in files:
197+
with open(file_path, "r", encoding="utf-8", errors="ignore") as file:
198+
content = file.read()
199+
if len(content.splitlines()) > 1500:
200+
continue
201+
if " pass" in content:
202+
filtered_files.append(file_path)
131203

132204
# Remove the base_dir prefix
133-
files = [file.replace(target_dir, "").lstrip("/") for file in files]
134-
205+
filtered_files = [
206+
file.replace(target_dir, "").lstrip("/") for file in filtered_files
207+
]
135208
# Only keep python files
136-
files = [file for file in files if file.endswith(".py")]
137209

138-
return files
210+
return filtered_files
139211

140212

141213
def get_message(
@@ -288,12 +360,12 @@ def get_changed_files(repo: git.Repo) -> list[str]:
288360
return files_changed
289361

290362

291-
def get_lint_cmd(repo: git.Repo, use_lint_info: bool) -> str:
292-
"""Generate a linting command based on whether to include files changed in the latest commit.
363+
def get_lint_cmd(repo_name: str, use_lint_info: bool) -> str:
364+
"""Generate a linting command based on whether to include files.
293365
294366
Args:
295367
----
296-
repo (git.Repo): An instance of GitPython's Repo object representing the Git repository.
368+
repo_name (str): The name of the repository.
297369
use_lint_info (bool): A flag indicating whether to include changed files in the lint command.
298370
299371
Returns:
@@ -304,7 +376,21 @@ def get_lint_cmd(repo: git.Repo, use_lint_info: bool) -> str:
304376
"""
305377
lint_cmd = "python -m commit0 lint "
306378
if use_lint_info:
307-
lint_cmd += " ".join(get_changed_files(repo))
379+
lint_cmd += repo_name + " --files "
308380
else:
309381
lint_cmd = ""
310382
return lint_cmd
383+
384+
385+
def write_agent_config(agent_config_file: str, agent_config: dict) -> None:
386+
"""Write the agent config to the file."""
387+
with open(agent_config_file, "w") as f:
388+
yaml.dump(agent_config, f)
389+
390+
391+
def read_yaml_config(config_file: str) -> dict:
392+
"""Read the yaml config from the file."""
393+
if not os.path.exists(config_file):
394+
raise FileNotFoundError(f"The config file '{config_file}' does not exist.")
395+
with open(config_file, "r") as f:
396+
return yaml.load(f, Loader=yaml.FullLoader)
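
The reworked `get_target_edit_files` keeps a file only if it is at most 1500 lines long and still contains a stubbed ` pass` body. The sketch below isolates just that filter under stated assumptions: `select_stub_files` is a hypothetical helper name, it works on `Path` objects rather than strings, and it leaves out the source/test directory heuristics of `_find_files_to_edit`.

```python
import tempfile
from pathlib import Path

MAX_LINES = 1500  # threshold used in the diff above


def select_stub_files(paths: list[Path], max_lines: int = MAX_LINES) -> list[Path]:
    """Keep files that are short enough and still contain a ' pass' stub."""
    selected = []
    for path in paths:
        # Mirror the diff: read as UTF-8 and ignore undecodable bytes.
        content = path.read_text(encoding="utf-8", errors="ignore")
        if len(content.splitlines()) > max_lines:
            continue  # too large for the agent to edit
        if " pass" in content:
            selected.append(path)
    return selected


# Usage: three throwaway files, of which only one should survive the filter.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "stub.py").write_text("def f():\n    pass\n", encoding="utf-8")
    (root / "done.py").write_text("def g():\n    return 1\n", encoding="utf-8")
    (root / "huge.py").write_text(
        "def h():\n    pass\n" + "x = 1\n" * 2000, encoding="utf-8"
    )
    kept_names = [p.name for p in select_stub_files(sorted(root.glob("*.py")))]

print(kept_names)
```

Note that the substring test is a heuristic: any file containing ` pass` anywhere, including in a comment or string, would be selected.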

agent/agents.py

Lines changed: 52 additions & 21 deletions
@@ -1,21 +1,51 @@
11
import sys
2-
import os
32
from abc import ABC, abstractmethod
43
from pathlib import Path
54
import logging
65

76
from aider.coders import Coder
87
from aider.models import Model
98
from aider.io import InputOutput
10-
from tenacity import retry, wait_exponential
9+
import re
10+
11+
12+
def handle_logging(logging_name: str, log_file: Path) -> None:
13+
"""Handle logging for agent"""
14+
logger = logging.getLogger(logging_name)
15+
logger.setLevel(logging.INFO)
16+
logger.propagate = False
17+
logger_handler = logging.FileHandler(log_file)
18+
logger_handler.setFormatter(
19+
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
20+
)
21+
logger.addHandler(logger_handler)
22+
23+
24+
class AgentReturn(ABC):
25+
def __init__(self, log_file: Path):
26+
self.log_file = log_file
27+
self.last_cost = self.get_money_cost()
28+
29+
def get_money_cost(self) -> float:
30+
"""Get accumulated money cost from log file"""
31+
last_cost = 0.0
32+
with open(self.log_file, "r") as file:
33+
for line in file:
34+
if "Tokens:" in line and "Cost:" in line:
35+
match = re.search(
36+
r"Cost: \$\d+\.\d+ message, \$(\d+\.\d+) session", line
37+
)
38+
if match:
39+
last_cost = float(match.group(1))
40+
return last_cost
1141

1242

1343
class Agents(ABC):
1444
def __init__(self, max_iteration: int):
1545
self.max_iteration = max_iteration
1646

1747
@abstractmethod
18-
def run(self) -> None:
48+
def run(self) -> AgentReturn:
1949
"""Start agent"""
2050
raise NotImplementedError
2151

@@ -25,17 +55,14 @@ def __init__(self, max_iteration: int, model_name: str):
2555
super().__init__(max_iteration)
2656
self.model = Model(model_name)
2757

28-
@retry(
29-
wait=wait_exponential(multiplier=1, min=4, max=10),
30-
)
3158
def run(
3259
self,
3360
message: str,
3461
test_cmd: str,
3562
lint_cmd: str,
3663
fnames: list[str],
3764
log_dir: Path,
38-
) -> None:
65+
) -> AgentReturn:
3966
"""Start aider agent"""
4067
if test_cmd:
4168
auto_test = True
@@ -50,10 +77,6 @@ def run(
5077
input_history_file = log_dir / ".aider.input.history"
5178
chat_history_file = log_dir / ".aider.chat.history.md"
5279

53-
print(
54-
f"check {os.path.abspath(chat_history_file)} for prompts and lm generations",
55-
file=sys.stderr,
56-
)
5780
# Set up logging
5881
log_file = log_dir / "aider.log"
5982
logging.basicConfig(
@@ -66,15 +89,9 @@ def run(
6689
sys.stdout = open(log_file, "a")
6790
sys.stderr = open(log_file, "a")
6891

69-
# Configure httpx logging
70-
httpx_logger = logging.getLogger("httpx")
71-
httpx_logger.setLevel(logging.INFO)
72-
httpx_logger.propagate = False # Prevent propagation to root logger
73-
httpx_handler = logging.FileHandler(log_file)
74-
httpx_handler.setFormatter(
75-
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
76-
)
77-
httpx_logger.addHandler(httpx_handler)
92+
# Configure httpx and backoff logging
93+
handle_logging("httpx", log_file)
94+
handle_logging("backoff", log_file)
7895

7996
io = InputOutput(
8097
yes=True,
@@ -91,14 +108,28 @@ def run(
91108
io=io,
92109
)
93110
coder.max_reflection = self.max_iteration
94-
coder.stream = False
111+
coder.stream = True
95112

96113
# Run the agent
97114
coder.run(message)
98115

116+
# #### TMP
117+
# import time
118+
# import random
119+
120+
# time.sleep(random.random() * 5)
121+
# n = random.random() / 10
122+
# with open(log_file, "a") as f:
123+
# f.write(
124+
# f"> Tokens: 33k sent, 1.3k received. Cost: $0.12 message, ${n} session. \n"
125+
# )
126+
# #### TMP
127+
99128
# Close redirected stdout and stderr
100129
sys.stdout.close()
101130
sys.stderr.close()
102131
# Restore original stdout and stderr
103132
sys.stdout = sys.__stdout__
104133
sys.stderr = sys.__stderr__
134+
135+
return AgentReturn(log_file)
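
`AgentReturn.get_money_cost` scans the log and keeps the session cost from the last matching line. A small sketch of that parsing step follows. It reuses the regular expression from the diff, but `last_session_cost` is a hypothetical function name, it takes an iterable of lines rather than a file path, and the sample lines are invented in the shape of the commented-out `TMP` line above.

```python
import re
from typing import Iterable

# Same pattern as in the diff: capture the accumulated session cost.
COST_PATTERN = re.compile(r"Cost: \$\d+\.\d+ message, \$(\d+\.\d+) session")


def last_session_cost(lines: Iterable[str]) -> float:
    """Return the session cost from the last cost line, or 0.0 if none."""
    last_cost = 0.0
    for line in lines:
        if "Tokens:" in line and "Cost:" in line:
            match = COST_PATTERN.search(line)
            if match:
                last_cost = float(match.group(1))
    return last_cost


# Invented sample log lines, for illustration only.
sample_log = [
    "> Tokens: 33k sent, 1.3k received. Cost: $0.12 message, $0.12 session.",
    "some unrelated log line",
    "> Tokens: 12k sent, 0.8k received. Cost: $0.05 message, $0.17 session.",
]

print(last_session_cost(sample_log))
```

Because the session figure is cumulative, taking the last match rather than summing every match avoids double counting.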
