Skip to content

Fix docker running error #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add test_fiel support
  • Loading branch information
nanjiangwill committed Sep 12, 2024
commit 953f4b600c23d1ac0ba461669b2be0db16ad1a6f
47 changes: 35 additions & 12 deletions baselines/baseline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
REFERENCE_HEADER = "\n\n>>> Here is the Reference for you to finish the task:\n"
REPO_INFO_HEADER = "\n\n>>> Here is the Repository Information:\n"
UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n"
EDIT_HISTORY_HEADER = "\n\n>>> Here is the Edit History:\n"
LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n"

# prefix components:
space = " "
Expand Down Expand Up @@ -154,31 +154,54 @@ def get_message_to_aider(
prompt = f"{PROMPT_HEADER} " + get_prompt(target_edit_files_cmd_args)

if aider_config.use_unit_tests_info and ds["test"]["test_dir"]:
unit_tests_info = f"\n{UNIT_TESTS_INFO_HEADER} " + get_dir_info(
dir_path=Path(os.path.join(repo_path, ds["test"]["test_dir"])),
prefix="",
include_stubs=True,
unit_tests_info = (
f"\n{UNIT_TESTS_INFO_HEADER} "
+ get_dir_info(
dir_path=Path(os.path.join(repo_path, ds["test"]["test_dir"])),
prefix="",
include_stubs=True,
)[: aider_config.max_unit_tests_info_length]
)
else:
unit_tests_info = ""

# TODO: assuming we have specification, which we currently do not have
if aider_config.use_reference_info and ds["specification"]:
reference = f"\n{REFERENCE_HEADER} " + get_reference(ds["specification"])
reference = (
f"\n{REFERENCE_HEADER} "
+ get_reference(ds["specification"])[
: aider_config.max_reference_info_length
]
)
else:
reference = ""

if aider_config.use_repo_info:
repo_info = f"\n{REPO_INFO_HEADER} " + get_dir_info(
dir_path=Path(repo_path), prefix="", max_depth=2, include_stubs=False
repo_info = (
f"\n{REPO_INFO_HEADER} "
+ get_dir_info(
dir_path=Path(repo_path), prefix="", max_depth=2, include_stubs=False
)[: aider_config.max_repo_info_length]
)
else:
repo_info = ""

message_to_aider = prompt + reference + repo_info + unit_tests_info
if aider_config.use_lint_info:
lint_info = (
f"\n{LINT_INFO_HEADER} "
+ subprocess.run(
["pre-commit", "run", "--all-files"], capture_output=True, text=True
).stdout
)[: aider_config.max_lint_info_length]
else:
lint_info = ""

message_to_aider = prompt + reference + repo_info + unit_tests_info + lint_info

return message_to_aider


def get_reference(specification_url: str) -> str:
"""Get the reference for a given specification URL."""
return f"/web {specification_url}"
def get_reference(specification_pdf_path: str) -> str:
"""Get the reference for a given specification PDF path."""
# TODO: after pdf_to_text is available, use it to extract the text from the PDF
return f"/pdf {specification_pdf_path}"
7 changes: 6 additions & 1 deletion baselines/class_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@ class Commit0Config(BaseModel):
class AiderConfig(BaseModel):
llm_name: str
use_repo_info: bool
max_repo_info_length: int
use_unit_tests_info: bool
max_unit_tests_info_length: int
use_reference_info: bool
max_reference_info_length: int
use_lint_info: bool
max_lint_info_length: int


class BaselineConfig(BaseModel):
config: Dict[str, Dict[str, Union[str, bool]]]
config: Dict[str, Dict[str, Union[str, bool, int]]]

commit0_config: Commit0Config | None = None
aider_config: AiderConfig | None = None
Expand Down
7 changes: 4 additions & 3 deletions baselines/config/aider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ defaults:
- _self_

aider_config:
use_repo_info: true
use_unit_tests_info: true
use_reference_info: false
use_repo_info: false
use_unit_tests_info: false
use_reference_info: false
use_lint_info: true
5 changes: 5 additions & 0 deletions baselines/config/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@ commit0_config:
aider_config:
llm_name: "claude-3-5-sonnet-20240620"
use_repo_info: false
max_repo_info_length: 10000
use_unit_tests_info: false
max_unit_tests_info_length: 10000
use_reference_info: false
max_reference_info_length: 10000
use_lint_info: false
max_lint_info_length: 10000

hydra:
run:
Expand Down
16 changes: 13 additions & 3 deletions baselines/run_aider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datasets import load_dataset
from omegaconf import OmegaConf
from tqdm.contrib.concurrent import thread_map

import tarfile
from baselines.baseline_utils import (
get_message_to_aider,
get_target_edit_files_cmd_args,
Expand Down Expand Up @@ -46,8 +46,18 @@ def run_aider_for_repo(
# get repo info
_, repo_name = ds["repo"].split("/")

# TODO: assuming we have all test_files, which we currently do not have
test_files = ds["test_files"]
repo_name = repo_name.lower()
repo_name = repo_name.replace(".", "-")
with tarfile.open(f"commit0/data/test_ids/{repo_name}.tar.bz2", "r:bz2") as tar:
for member in tar.getmembers():
if member.isfile():
file = tar.extractfile(member)
if file:
test_files_str = file.read().decode("utf-8")
# print(content.decode("utf-8"))

test_files = test_files_str.split("\n") if isinstance(test_files_str, str) else []
test_files = sorted(list(set([i.split(":")[0] for i in test_files])))

repo_path = os.path.join(commit0_config.base_dir, repo_name)

Expand Down
Loading