Skip to content

Commit 1904f27

Browse files
committed
aider v1
1 parent 07fccd1 commit 1904f27

File tree

5 files changed

+47
-31
lines changed

5 files changed

+47
-31
lines changed

baselines/baseline_utils.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,12 @@ def get_target_edit_files_cmd_args(target_dir: str) -> str:
140140
# Split the output into lines and remove the base_dir prefix
141141
files = result.stdout.strip().split("\n")
142142

143+
# Remove the base_dir prefix
144+
files = [file.replace(target_dir, "").lstrip("/") for file in files]
145+
146+
# Only keep python files
147+
files = [file for file in files if file.endswith(".py")]
148+
143149
return " ".join(files)
144150

145151

@@ -186,17 +192,7 @@ def get_message_to_aider(
186192
else:
187193
repo_info = ""
188194

189-
if aider_config.use_lint_info:
190-
lint_info = (
191-
f"\n{LINT_INFO_HEADER} "
192-
+ subprocess.run(
193-
["pre-commit", "run", "--all-files"], capture_output=True, text=True
194-
).stdout
195-
)[: aider_config.max_lint_info_length]
196-
else:
197-
lint_info = ""
198-
199-
message_to_aider = prompt + reference + repo_info + unit_tests_info + lint_info
195+
message_to_aider = prompt + reference + repo_info + unit_tests_info
200196

201197
return message_to_aider
202198

baselines/class_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
class Commit0Config(BaseModel):
77
base_dir: str
88
dataset_name: str
9+
repo_split: str
910

1011

1112
class AiderConfig(BaseModel):

baselines/config/aider.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ aider_config:
77
use_repo_info: false
88
use_unit_tests_info: false
99
use_reference_info: false
10-
use_lint_info: false
10+
use_lint_info: true
11+
pre_commit_config_path: .pre-commit-config.yaml

baselines/config/base.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ defaults:
66
commit0_config:
77
base_dir: /Users/willjiang/Desktop/ai2dev/commit0/repos
88
dataset_name: "wentingzhao/commit0_docstring"
9+
repo_split: "lite"
910

1011
aider_config:
1112
llm_name: "claude-3-5-sonnet-20240620"
@@ -17,6 +18,7 @@ aider_config:
1718
max_reference_info_length: 10000
1819
use_lint_info: false
1920
max_lint_info_length: 10000
21+
pre_commit_config_path: .pre-commit-config.yaml
2022

2123
hydra:
2224
run:

baselines/run_aider.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
import logging
22
import os
33
import subprocess
4-
from functools import partial
54

65
import hydra
76
from datasets import load_dataset
87
from omegaconf import OmegaConf
9-
from tqdm.contrib.concurrent import thread_map
108
import tarfile
119
from baselines.baseline_utils import (
1210
get_message_to_aider,
1311
get_target_edit_files_cmd_args,
1412
)
1513
from baselines.class_types import AiderConfig, BaselineConfig, Commit0Config
16-
14+
from commit0.harness.constants import SPLIT
1715
# from aider.run_aider import get_aider_cmd
1816

1917
logging.basicConfig(
@@ -27,11 +25,16 @@ def get_aider_cmd(
2725
files: str,
2826
message_to_aider: str,
2927
test_cmd: str,
28+
lint_cmd: str,
3029
) -> str:
3130
"""Get the Aider command based on the given context."""
32-
aider_cmd = f"aider --model {model} --file {files} --message \"{message_to_aider}\" --auto-test --test --test-cmd '{test_cmd}' --yes"
33-
34-
return aider_cmd
31+
base_cmd = f'aider --model {model} --file {files} --message "{message_to_aider}"'
32+
if lint_cmd:
33+
base_cmd += f" --auto-lint --lint-cmd '{lint_cmd}'"
34+
if test_cmd:
35+
base_cmd += f" --auto-test --test --test-cmd '{test_cmd}'"
36+
base_cmd += " --yes"
37+
return base_cmd
3538

3639

3740
def run_aider_for_repo(
@@ -61,29 +64,39 @@ def run_aider_for_repo(
6164

6265
repo_path = os.path.join(commit0_config.base_dir, repo_name)
6366

67+
os.chdir(repo_path)
68+
6469
target_edit_files_cmd_args = get_target_edit_files_cmd_args(repo_path)
6570

6671
message_to_aider = get_message_to_aider(
6772
aider_config, target_edit_files_cmd_args, repo_path, ds
6873
)
6974

70-
test_files = test_files[:1]
75+
if aider_config.use_lint_info:
76+
lint_cmd = "pre-commit run --config ../../.pre-commit-config.yaml --files"
77+
else:
78+
lint_cmd = ""
79+
7180
for test_file in test_files:
72-
test_cmd = f"uv run commit0 test-reference {repo_name} {test_file}"
81+
test_cmd = f"python -m commit0 test {repo_name} {test_file}"
7382

7483
aider_cmd = get_aider_cmd(
7584
aider_config.llm_name,
7685
target_edit_files_cmd_args,
7786
message_to_aider,
7887
test_cmd,
88+
lint_cmd,
7989
)
8090

81-
print(aider_cmd)
82-
8391
try:
84-
process = subprocess.Popen(aider_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
92+
process = subprocess.Popen(
93+
aider_cmd,
94+
shell=True,
95+
stdout=subprocess.PIPE,
96+
stderr=subprocess.PIPE,
97+
universal_newlines=True,
98+
)
8599
stdout, stderr = process.communicate()
86-
results = process.returncode
87100
logger.info(f"STDOUT: {stdout}")
88101
logger.info(f"STDERR: {stderr}")
89102
except subprocess.CalledProcessError as e:
@@ -97,6 +110,7 @@ def run_aider_for_repo(
97110
logger.error(f"Command: {''.join(aider_cmd)}")
98111
else:
99112
logger.error(f"OSError occurred: {e}")
113+
asdf
100114

101115

102116
@hydra.main(version_base=None, config_path="config", config_name="aider")
@@ -114,13 +128,15 @@ def main(config: BaselineConfig) -> None:
114128

115129
dataset = load_dataset(commit0_config.dataset_name, split="test")
116130

117-
dataset = [dataset[3]]
118-
thread_map(
119-
partial(run_aider_for_repo, commit0_config, aider_config),
120-
dataset,
121-
desc="Running aider for repos",
122-
max_workers=10,
123-
)
131+
filtered_dataset = [
132+
example
133+
for example in dataset
134+
if commit0_config.repo_split == "all"
135+
or example["repo"].split("/")[-1] in SPLIT.get(commit0_config.repo_split, [])
136+
]
137+
138+
for example in filtered_dataset:
139+
run_aider_for_repo(commit0_config, aider_config, example)
124140

125141

126142
if __name__ == "__main__":

0 commit comments

Comments
 (0)