Skip to content

Commit 56aad67

Browse files
authored
Merge pull request #12 from commit-0/aider
Adding Aider pipeline for commit0
2 parents 881d931 + f1e4cf3 commit 56aad67

File tree

9 files changed

+553
-0
lines changed

9 files changed

+553
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,4 @@ cython_debug/
164164
logs/
165165
repos/
166166
config.yml
167+
hydra_outputs/

baselines/README.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# How to run baseline
2+
3+
Step 1: Go to `config/aider.yaml` and change the config
4+
5+
Step 2: Run the following command
6+
7+
```bash
8+
python baselines/run_aider.py
9+
```
10+
11+
## Config
12+
13+
`commit0_config`:
14+
15+
- `base_dir`: Repos dir. Default `repos`.
16+
- `dataset_name`: commit0 HF dataset name. Default: `wentingzhao/commit0_docstring`.
17+
- `dataset_split`: commit0 dataset split. Default: `test`.
18+
- `repo_split`: commit0 repo split. Default: `simpy`.
19+
- `num_workers`: number of workers to run in parallel. Default: `10`.
20+
21+
`aider_config`:
22+
23+
- `llm_name`: LLM model name. Default: `claude-3-5-sonnet-20240620`.
24+
- `use_user_prompt`: Whether to use user prompt. Default: `false`.
25+
- `user_prompt`: User prompt. Default: `""`.
26+
- `use_repo_info`: Whether to use repo info. Default: `false`.
27+
- Repo info
28+
- skeleton of the repo(filenames under each dir)
29+
- function stubs
30+
31+
- `use_unit_tests_info`: Whether to use unit tests: unit_tests that target will be tested with. Default: `false`.
32+
- `use_reference_info`: Whether to use reference: reference doc/pdf/website. Default: `false`.
33+
- `use_lint_info`: Whether to use lint: lint info. Default: `false`.
34+
- `pre_commit_config_path`: Path to pre-commit config. Default: `.pre-commit-config.yaml`.
35+
- `run_tests`: Whether to run tests. Default: `true`.
36+
- `max_repo_info_length`: Max length of repo info. Default: `10000`.
37+
- `max_unit_tests_info_length`: Max length of unit tests info. Default: `10000`.
38+
- `max_reference_info_length`: Max length of reference info. Default: `10000`.
39+
- `max_lint_info_length`: Max length of lint info. Default: `10000`.

baselines/baseline_utils.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
import os
2+
import re
3+
import subprocess
4+
from pathlib import Path
5+
from typing import Any, Dict, List
6+
7+
from baselines.class_types import AiderConfig
8+
9+
PROMPT_HEADER = ">>> Here is the Task:\n"
10+
REFERENCE_HEADER = "\n\n>>> Here is the Reference for you to finish the task:\n"
11+
REPO_INFO_HEADER = "\n\n>>> Here is the Repository Information:\n"
12+
UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n"
13+
LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n"
14+
15+
# prefix components:
16+
space = " "
17+
branch = "│ "
18+
# pointers:
19+
tee = "├── "
20+
last = "└── "
21+
22+
23+
def extract_function_stubs(file_path: Path) -> List[str]:
24+
"""Extract function stubs from a Python file, including type hints."""
25+
with open(file_path, "r") as file:
26+
content = file.read()
27+
28+
# Regular expression to match function definitions with optional type hints
29+
# This pattern now stops at the colon that ends the function signature
30+
pattern = (
31+
r"def\s+(\w+)\s*\(((?:[^()]*|\([^()]*\))*)\)\s*(?:->\s*([\w\[\],\s|]+))?\s*:"
32+
)
33+
matches = re.findall(pattern, content)
34+
35+
stubs = []
36+
for name, args, return_type in matches:
37+
# Process arguments to include type hints
38+
processed_args = []
39+
for arg in args.split(","):
40+
arg = arg.strip()
41+
if ":" in arg:
42+
arg_name, arg_type = arg.split(":", 1)
43+
processed_args.append(f"{arg_name.strip()}: {arg_type.strip()}")
44+
else:
45+
processed_args.append(arg)
46+
47+
args_str = ", ".join(processed_args)
48+
49+
# Include return type if present
50+
return_annotation = f" -> {return_type.strip()}" if return_type else ""
51+
52+
stubs.append(f"def {name}({args_str}){return_annotation}: ...")
53+
54+
return stubs
55+
56+
57+
def get_dir_info(
58+
dir_path: Path,
59+
prefix: str = "",
60+
max_depth: int = 10,
61+
include_stubs: bool = False,
62+
current_depth: int = 0,
63+
ignore_dot_files: bool = True,
64+
) -> str:
65+
"""A recursive generator, given a directory Path object will yield a visual
66+
tree structure line by line with each line prefixed by the same characters.
67+
68+
Args:
69+
----
70+
dir_path (Path): The directory to traverse
71+
prefix (str): The prefix to use for the current level
72+
max_depth (int): The maximum depth to traverse (default: infinite)
73+
current_depth (int): The current depth of traversal (used internally)
74+
ignore_dot_files (bool): Whether to ignore files/directories starting with a dot (default: True)
75+
include_stubs (bool): Whether to include function stubs for Python files (default: True)
76+
77+
"""
78+
if current_depth >= max_depth:
79+
return ""
80+
81+
contents = list(dir_path.iterdir())
82+
83+
if ignore_dot_files:
84+
contents = [c for c in contents if not c.name.startswith(".")]
85+
86+
tree_string = []
87+
# contents each get pointers that are ├── with a final └── :
88+
pointers = [tee] * (len(contents) - 1) + [last]
89+
for pointer, path in zip(pointers, contents):
90+
tree_string.append(prefix + pointer + path.name)
91+
if path.is_dir():
92+
extension = branch if pointer == tee else space
93+
tree_string.append(
94+
get_dir_info(
95+
path,
96+
prefix=prefix + extension,
97+
max_depth=max_depth,
98+
include_stubs=include_stubs,
99+
current_depth=current_depth + 1,
100+
ignore_dot_files=ignore_dot_files,
101+
)
102+
)
103+
elif include_stubs and path.suffix == ".py":
104+
stubs = extract_function_stubs(path)
105+
for stub in stubs:
106+
tree_string.append(prefix + space + space + stub)
107+
return "\n".join(filter(None, tree_string))
108+
109+
110+
def get_file_info(file_path: Path, prefix: str = "") -> str:
111+
"""Return the contents of a file with a given prefix."""
112+
tree_string = [tee + file_path.name]
113+
stubs = extract_function_stubs(file_path)
114+
for stub in stubs:
115+
tree_string.append(prefix + space + space + stub)
116+
return "\n".join(filter(None, tree_string))
117+
118+
119+
def get_prompt(file_list: str) -> str:
120+
"""Get the prompt for the Aider model."""
121+
return """Here is the Task:\n Your task is to iteratively implement the each function that is 'NotImplementedError('IMPLEMENT ME HERE')' in these files until there are no more 'NotImplementedError('IMPLEMENT ME HERE')' and pass the unit tests.
122+
Make sure you read the files carefully.
123+
Your output should be the edited code files.
124+
Use the above instructions to modify the supplied files: {file_list}
125+
Do not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.
126+
Only use standard python libraries, do not suggest installing any packages.
127+
""".format(file_list=file_list)
128+
129+
130+
def get_target_edit_files_cmd_args(target_dir: str) -> str:
131+
"""Find the files with the error 'NotImplementedError('IMPLEMENT ME
132+
HERE')'.
133+
"""
134+
# The grep command
135+
command = f"grep -R -l \"NotImplementedError('IMPLEMENT ME HERE')\" {target_dir}"
136+
137+
# Run the command and capture the output
138+
result = subprocess.run(command, shell=True, capture_output=True, text=True)
139+
140+
# Split the output into lines and remove the base_dir prefix
141+
files = result.stdout.strip().split("\n")
142+
143+
# Remove the base_dir prefix
144+
files = [file.replace(target_dir, "").lstrip("/") for file in files]
145+
146+
# Only keep python files
147+
files = [file for file in files if file.endswith(".py")]
148+
149+
return " ".join(files)
150+
151+
152+
def get_message_to_aider(
153+
aider_config: AiderConfig,
154+
target_edit_files_cmd_args: str,
155+
repo_path: str,
156+
ds: Dict[str, Any],
157+
) -> str:
158+
"""Get the message to Aider."""
159+
# support context for aider
160+
if aider_config.use_user_prompt:
161+
assert (
162+
aider_config.user_prompt != ""
163+
), "You choose to use custom user prompt, but it is empty"
164+
prompt = f"{PROMPT_HEADER} " + aider_config.user_prompt
165+
else:
166+
prompt = f"{PROMPT_HEADER} " + get_prompt(target_edit_files_cmd_args)
167+
168+
if aider_config.use_unit_tests_info and ds["test"]["test_dir"]:
169+
unit_tests_info = (
170+
f"\n{UNIT_TESTS_INFO_HEADER} "
171+
+ get_dir_info(
172+
dir_path=Path(os.path.join(repo_path, ds["test"]["test_dir"])),
173+
prefix="",
174+
include_stubs=True,
175+
)[: aider_config.max_unit_tests_info_length]
176+
)
177+
else:
178+
unit_tests_info = ""
179+
180+
# TODO: assuming we have specification, which we currently do not have
181+
if aider_config.use_reference_info and ds["specification"]:
182+
reference = (
183+
f"\n{REFERENCE_HEADER} "
184+
+ get_reference(ds["specification"])[
185+
: aider_config.max_reference_info_length
186+
]
187+
)
188+
else:
189+
reference = ""
190+
191+
if aider_config.use_repo_info:
192+
repo_info = (
193+
f"\n{REPO_INFO_HEADER} "
194+
+ get_dir_info(
195+
dir_path=Path(repo_path), prefix="", max_depth=2, include_stubs=False
196+
)[: aider_config.max_repo_info_length]
197+
)
198+
else:
199+
repo_info = ""
200+
201+
message_to_aider = prompt + reference + repo_info + unit_tests_info
202+
203+
return message_to_aider
204+
205+
206+
def get_reference(specification_pdf_path: str) -> str:
207+
"""Get the reference for a given specification PDF path."""
208+
# TODO: after pdf_to_text is available, use it to extract the text from the PDF
209+
return f"/pdf {specification_pdf_path}"

baselines/class_types.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass
5+
class Commit0Config:
6+
base_dir: str
7+
dataset_name: str
8+
dataset_split: str
9+
repo_split: str
10+
num_workers: int
11+
12+
13+
@dataclass
14+
class AiderConfig:
15+
llm_name: str
16+
use_user_prompt: bool
17+
user_prompt: str
18+
use_repo_info: bool
19+
max_repo_info_length: int
20+
use_unit_tests_info: bool
21+
max_unit_tests_info_length: int
22+
use_reference_info: bool
23+
max_reference_info_length: int
24+
use_lint_info: bool
25+
max_lint_info_length: int
26+
pre_commit_config_path: str
27+
run_tests: bool

baselines/configs/aider.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# @package _global_
2+
defaults:
3+
- base
4+
- _self_
5+
6+
aider_config:
7+
use_user_prompt: false
8+
use_repo_info: false
9+
use_unit_tests_info: false
10+
use_reference_info: false
11+
use_lint_info: true
12+
pre_commit_config_path: .pre-commit-config.yaml
13+
run_tests: true

baselines/configs/base.yaml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
defaults:
2+
- _self_
3+
4+
5+
6+
commit0_config:
7+
base_dir: /Users/willjiang/Desktop/ai2dev/commit0/repos
8+
dataset_name: "wentingzhao/commit0_docstring"
9+
dataset_split: "test"
10+
repo_split: "simpy"
11+
num_workers: 10
12+
13+
aider_config:
14+
llm_name: "claude-3-5-sonnet-20240620"
15+
use_user_prompt: false
16+
user_prompt: ""
17+
use_repo_info: false
18+
use_unit_tests_info: false
19+
use_reference_info: false
20+
use_lint_info: false
21+
pre_commit_config_path: .pre-commit-config.yaml
22+
run_tests: True
23+
max_repo_info_length: 10000
24+
max_unit_tests_info_length: 10000
25+
max_reference_info_length: 10000
26+
max_lint_info_length: 10000
27+
28+
hydra:
29+
run:
30+
dir: ./hydra_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}

0 commit comments

Comments
 (0)