Skip to content

Adding Aider pipeline for commit0 #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,4 @@ cython_debug/
logs/
repos/
config.yml
hydra_outputs/
39 changes: 39 additions & 0 deletions baselines/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# How to run baseline

Step 1: Go to `config/aider.yaml` and change the config

Step 2: Run the following command

```bash
python baselines/run_aider.py
```

## Config

`commit0_config`:

- `base_dir`: Repos dir. Default `repos`.
- `dataset_name`: commit0 HF dataset name. Default: `wentingzhao/commit0_docstring`.
- `dataset_split`: commit0 dataset split. Default: `test`.
- `repo_split`: commit0 repo split. Default: `simpy`.
- `num_workers`: number of workers to run in parallel. Default: `10`.

`aider_config`:

- `llm_name`: LLM model name. Default: `claude-3-5-sonnet-20240620`.
- `use_user_prompt`: Whether to use user prompt. Default: `false`.
- `user_prompt`: User prompt. Default: `""`.
- `use_repo_info`: Whether to use repo info. Default: `false`.
- Repo info
- skeleton of the repo(filenames under each dir)
- function stubs

- `use_unit_tests_info`: Whether to use unit tests: unit_tests that target will be tested with. Default: `false`.
- `use_reference_info`: Whether to use reference: reference doc/pdf/website. Default: `false`.
- `use_lint_info`: Whether to use lint: lint info. Default: `false`.
- `pre_commit_config_path`: Path to pre-commit config. Default: `.pre-commit-config.yaml`.
- `run_tests`: Whether to run tests. Default: `true`.
- `max_repo_info_length`: Max length of repo info. Default: `10000`.
- `max_unit_tests_info_length`: Max length of unit tests info. Default: `10000`.
- `max_reference_info_length`: Max length of reference info. Default: `10000`.
- `max_lint_info_length`: Max length of lint info. Default: `10000`.
209 changes: 209 additions & 0 deletions baselines/baseline_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import os
import re
import subprocess
from pathlib import Path
from typing import Any, Dict, List

from baselines.class_types import AiderConfig

PROMPT_HEADER = ">>> Here is the Task:\n"
REFERENCE_HEADER = "\n\n>>> Here is the Reference for you to finish the task:\n"
REPO_INFO_HEADER = "\n\n>>> Here is the Repository Information:\n"
UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n"
LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n"

# prefix components:
space = " "
branch = "│ "
# pointers:
tee = "├── "
last = "└── "


def extract_function_stubs(file_path: Path) -> List[str]:
"""Extract function stubs from a Python file, including type hints."""
with open(file_path, "r") as file:
content = file.read()

# Regular expression to match function definitions with optional type hints
# This pattern now stops at the colon that ends the function signature
pattern = (
r"def\s+(\w+)\s*\(((?:[^()]*|\([^()]*\))*)\)\s*(?:->\s*([\w\[\],\s|]+))?\s*:"
)
matches = re.findall(pattern, content)

stubs = []
for name, args, return_type in matches:
# Process arguments to include type hints
processed_args = []
for arg in args.split(","):
arg = arg.strip()
if ":" in arg:
arg_name, arg_type = arg.split(":", 1)
processed_args.append(f"{arg_name.strip()}: {arg_type.strip()}")
else:
processed_args.append(arg)

args_str = ", ".join(processed_args)

# Include return type if present
return_annotation = f" -> {return_type.strip()}" if return_type else ""

stubs.append(f"def {name}({args_str}){return_annotation}: ...")

return stubs


def get_dir_info(
dir_path: Path,
prefix: str = "",
max_depth: int = 10,
include_stubs: bool = False,
current_depth: int = 0,
ignore_dot_files: bool = True,
) -> str:
"""A recursive generator, given a directory Path object will yield a visual
tree structure line by line with each line prefixed by the same characters.

Args:
----
dir_path (Path): The directory to traverse
prefix (str): The prefix to use for the current level
max_depth (int): The maximum depth to traverse (default: infinite)
current_depth (int): The current depth of traversal (used internally)
ignore_dot_files (bool): Whether to ignore files/directories starting with a dot (default: True)
include_stubs (bool): Whether to include function stubs for Python files (default: True)

"""
if current_depth >= max_depth:
return ""

contents = list(dir_path.iterdir())

if ignore_dot_files:
contents = [c for c in contents if not c.name.startswith(".")]

tree_string = []
# contents each get pointers that are ├── with a final └── :
pointers = [tee] * (len(contents) - 1) + [last]
for pointer, path in zip(pointers, contents):
tree_string.append(prefix + pointer + path.name)
if path.is_dir():
extension = branch if pointer == tee else space
tree_string.append(
get_dir_info(
path,
prefix=prefix + extension,
max_depth=max_depth,
include_stubs=include_stubs,
current_depth=current_depth + 1,
ignore_dot_files=ignore_dot_files,
)
)
elif include_stubs and path.suffix == ".py":
stubs = extract_function_stubs(path)
for stub in stubs:
tree_string.append(prefix + space + space + stub)
return "\n".join(filter(None, tree_string))


def get_file_info(file_path: Path, prefix: str = "") -> str:
"""Return the contents of a file with a given prefix."""
tree_string = [tee + file_path.name]
stubs = extract_function_stubs(file_path)
for stub in stubs:
tree_string.append(prefix + space + space + stub)
return "\n".join(filter(None, tree_string))


def get_prompt(file_list: str) -> str:
"""Get the prompt for the Aider model."""
return """Here is the Task:\n Your task is to iteratively implement the each function that is 'NotImplementedError('IMPLEMENT ME HERE')' in these files until there are no more 'NotImplementedError('IMPLEMENT ME HERE')' and pass the unit tests.
Make sure you read the files carefully.
Your output should be the edited code files.
Use the above instructions to modify the supplied files: {file_list}
Do not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.
Only use standard python libraries, do not suggest installing any packages.
""".format(file_list=file_list)


def get_target_edit_files_cmd_args(target_dir: str) -> str:
"""Find the files with the error 'NotImplementedError('IMPLEMENT ME
HERE')'.
"""
# The grep command
command = f"grep -R -l \"NotImplementedError('IMPLEMENT ME HERE')\" {target_dir}"

# Run the command and capture the output
result = subprocess.run(command, shell=True, capture_output=True, text=True)

# Split the output into lines and remove the base_dir prefix
files = result.stdout.strip().split("\n")

# Remove the base_dir prefix
files = [file.replace(target_dir, "").lstrip("/") for file in files]

# Only keep python files
files = [file for file in files if file.endswith(".py")]

return " ".join(files)


def get_message_to_aider(
aider_config: AiderConfig,
target_edit_files_cmd_args: str,
repo_path: str,
ds: Dict[str, Any],
) -> str:
"""Get the message to Aider."""
# support context for aider
if aider_config.use_user_prompt:
assert (
aider_config.user_prompt != ""
), "You choose to use custom user prompt, but it is empty"
prompt = f"{PROMPT_HEADER} " + aider_config.user_prompt
else:
prompt = f"{PROMPT_HEADER} " + get_prompt(target_edit_files_cmd_args)

if aider_config.use_unit_tests_info and ds["test"]["test_dir"]:
unit_tests_info = (
f"\n{UNIT_TESTS_INFO_HEADER} "
+ get_dir_info(
dir_path=Path(os.path.join(repo_path, ds["test"]["test_dir"])),
prefix="",
include_stubs=True,
)[: aider_config.max_unit_tests_info_length]
)
else:
unit_tests_info = ""

# TODO: assuming we have specification, which we currently do not have
if aider_config.use_reference_info and ds["specification"]:
reference = (
f"\n{REFERENCE_HEADER} "
+ get_reference(ds["specification"])[
: aider_config.max_reference_info_length
]
)
else:
reference = ""

if aider_config.use_repo_info:
repo_info = (
f"\n{REPO_INFO_HEADER} "
+ get_dir_info(
dir_path=Path(repo_path), prefix="", max_depth=2, include_stubs=False
)[: aider_config.max_repo_info_length]
)
else:
repo_info = ""

message_to_aider = prompt + reference + repo_info + unit_tests_info

return message_to_aider


def get_reference(specification_pdf_path: str) -> str:
"""Get the reference for a given specification PDF path."""
# TODO: after pdf_to_text is available, use it to extract the text from the PDF
return f"/pdf {specification_pdf_path}"
27 changes: 27 additions & 0 deletions baselines/class_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from dataclasses import dataclass


@dataclass
class Commit0Config:
base_dir: str
dataset_name: str
dataset_split: str
repo_split: str
num_workers: int


@dataclass
class AiderConfig:
llm_name: str
use_user_prompt: bool
user_prompt: str
use_repo_info: bool
max_repo_info_length: int
use_unit_tests_info: bool
max_unit_tests_info_length: int
use_reference_info: bool
max_reference_info_length: int
use_lint_info: bool
max_lint_info_length: int
pre_commit_config_path: str
run_tests: bool
13 changes: 13 additions & 0 deletions baselines/configs/aider.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# @package _global_
defaults:
- base
- _self_

aider_config:
use_user_prompt: false
use_repo_info: false
use_unit_tests_info: false
use_reference_info: false
use_lint_info: true
pre_commit_config_path: .pre-commit-config.yaml
run_tests: true
30 changes: 30 additions & 0 deletions baselines/configs/base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
defaults:
- _self_



commit0_config:
base_dir: /Users/willjiang/Desktop/ai2dev/commit0/repos
dataset_name: "wentingzhao/commit0_docstring"
dataset_split: "test"
repo_split: "simpy"
num_workers: 10

aider_config:
llm_name: "claude-3-5-sonnet-20240620"
use_user_prompt: false
user_prompt: ""
use_repo_info: false
use_unit_tests_info: false
use_reference_info: false
use_lint_info: false
pre_commit_config_path: .pre-commit-config.yaml
run_tests: True
max_repo_info_length: 10000
max_unit_tests_info_length: 10000
max_reference_info_length: 10000
max_lint_info_length: 10000

hydra:
run:
dir: ./hydra_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
Loading
Loading