|
| 1 | +import os |
| 2 | +import re |
| 3 | +import subprocess |
| 4 | +from pathlib import Path |
| 5 | +from typing import Any, Dict, List |
| 6 | + |
| 7 | +from baselines.class_types import AiderConfig |
| 8 | + |
| 9 | +PROMPT_HEADER = ">>> Here is the Task:\n" |
| 10 | +REFERENCE_HEADER = "\n\n>>> Here is the Reference for you to finish the task:\n" |
| 11 | +REPO_INFO_HEADER = "\n\n>>> Here is the Repository Information:\n" |
| 12 | +UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n" |
| 13 | +LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n" |
| 14 | + |
| 15 | +# prefix components: |
| 16 | +space = " " |
| 17 | +branch = "│ " |
| 18 | +# pointers: |
| 19 | +tee = "├── " |
| 20 | +last = "└── " |
| 21 | + |
| 22 | + |
| 23 | +def extract_function_stubs(file_path: Path) -> List[str]: |
| 24 | + """Extract function stubs from a Python file, including type hints.""" |
| 25 | + with open(file_path, "r") as file: |
| 26 | + content = file.read() |
| 27 | + |
| 28 | + # Regular expression to match function definitions with optional type hints |
| 29 | + # This pattern now stops at the colon that ends the function signature |
| 30 | + pattern = ( |
| 31 | + r"def\s+(\w+)\s*\(((?:[^()]*|\([^()]*\))*)\)\s*(?:->\s*([\w\[\],\s|]+))?\s*:" |
| 32 | + ) |
| 33 | + matches = re.findall(pattern, content) |
| 34 | + |
| 35 | + stubs = [] |
| 36 | + for name, args, return_type in matches: |
| 37 | + # Process arguments to include type hints |
| 38 | + processed_args = [] |
| 39 | + for arg in args.split(","): |
| 40 | + arg = arg.strip() |
| 41 | + if ":" in arg: |
| 42 | + arg_name, arg_type = arg.split(":", 1) |
| 43 | + processed_args.append(f"{arg_name.strip()}: {arg_type.strip()}") |
| 44 | + else: |
| 45 | + processed_args.append(arg) |
| 46 | + |
| 47 | + args_str = ", ".join(processed_args) |
| 48 | + |
| 49 | + # Include return type if present |
| 50 | + return_annotation = f" -> {return_type.strip()}" if return_type else "" |
| 51 | + |
| 52 | + stubs.append(f"def {name}({args_str}){return_annotation}: ...") |
| 53 | + |
| 54 | + return stubs |
| 55 | + |
| 56 | + |
| 57 | +def get_dir_info( |
| 58 | + dir_path: Path, |
| 59 | + prefix: str = "", |
| 60 | + max_depth: int = 10, |
| 61 | + include_stubs: bool = False, |
| 62 | + current_depth: int = 0, |
| 63 | + ignore_dot_files: bool = True, |
| 64 | +) -> str: |
| 65 | + """A recursive generator, given a directory Path object will yield a visual |
| 66 | + tree structure line by line with each line prefixed by the same characters. |
| 67 | +
|
| 68 | + Args: |
| 69 | + ---- |
| 70 | + dir_path (Path): The directory to traverse |
| 71 | + prefix (str): The prefix to use for the current level |
| 72 | + max_depth (int): The maximum depth to traverse (default: infinite) |
| 73 | + current_depth (int): The current depth of traversal (used internally) |
| 74 | + ignore_dot_files (bool): Whether to ignore files/directories starting with a dot (default: True) |
| 75 | + include_stubs (bool): Whether to include function stubs for Python files (default: True) |
| 76 | +
|
| 77 | + """ |
| 78 | + if current_depth >= max_depth: |
| 79 | + return "" |
| 80 | + |
| 81 | + contents = list(dir_path.iterdir()) |
| 82 | + |
| 83 | + if ignore_dot_files: |
| 84 | + contents = [c for c in contents if not c.name.startswith(".")] |
| 85 | + |
| 86 | + tree_string = [] |
| 87 | + # contents each get pointers that are ├── with a final └── : |
| 88 | + pointers = [tee] * (len(contents) - 1) + [last] |
| 89 | + for pointer, path in zip(pointers, contents): |
| 90 | + tree_string.append(prefix + pointer + path.name) |
| 91 | + if path.is_dir(): |
| 92 | + extension = branch if pointer == tee else space |
| 93 | + tree_string.append( |
| 94 | + get_dir_info( |
| 95 | + path, |
| 96 | + prefix=prefix + extension, |
| 97 | + max_depth=max_depth, |
| 98 | + include_stubs=include_stubs, |
| 99 | + current_depth=current_depth + 1, |
| 100 | + ignore_dot_files=ignore_dot_files, |
| 101 | + ) |
| 102 | + ) |
| 103 | + elif include_stubs and path.suffix == ".py": |
| 104 | + stubs = extract_function_stubs(path) |
| 105 | + for stub in stubs: |
| 106 | + tree_string.append(prefix + space + space + stub) |
| 107 | + return "\n".join(filter(None, tree_string)) |
| 108 | + |
| 109 | + |
| 110 | +def get_file_info(file_path: Path, prefix: str = "") -> str: |
| 111 | + """Return the contents of a file with a given prefix.""" |
| 112 | + tree_string = [tee + file_path.name] |
| 113 | + stubs = extract_function_stubs(file_path) |
| 114 | + for stub in stubs: |
| 115 | + tree_string.append(prefix + space + space + stub) |
| 116 | + return "\n".join(filter(None, tree_string)) |
| 117 | + |
| 118 | + |
| 119 | +def get_prompt(file_list: str) -> str: |
| 120 | + """Get the prompt for the Aider model.""" |
| 121 | + return """Here is the Task:\n Your task is to iteratively implement the each function that is 'NotImplementedError('IMPLEMENT ME HERE')' in these files until there are no more 'NotImplementedError('IMPLEMENT ME HERE')' and pass the unit tests. |
| 122 | +Make sure you read the files carefully. |
| 123 | +Your output should be the edited code files. |
| 124 | +Use the above instructions to modify the supplied files: {file_list} |
| 125 | +Do not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc. |
| 126 | +Only use standard python libraries, do not suggest installing any packages. |
| 127 | +""".format(file_list=file_list) |
| 128 | + |
| 129 | + |
| 130 | +def get_target_edit_files_cmd_args(target_dir: str) -> str: |
| 131 | + """Find the files with the error 'NotImplementedError('IMPLEMENT ME |
| 132 | + HERE')'. |
| 133 | + """ |
| 134 | + # The grep command |
| 135 | + command = f"grep -R -l \"NotImplementedError('IMPLEMENT ME HERE')\" {target_dir}" |
| 136 | + |
| 137 | + # Run the command and capture the output |
| 138 | + result = subprocess.run(command, shell=True, capture_output=True, text=True) |
| 139 | + |
| 140 | + # Split the output into lines and remove the base_dir prefix |
| 141 | + files = result.stdout.strip().split("\n") |
| 142 | + |
| 143 | + # Remove the base_dir prefix |
| 144 | + files = [file.replace(target_dir, "").lstrip("/") for file in files] |
| 145 | + |
| 146 | + # Only keep python files |
| 147 | + files = [file for file in files if file.endswith(".py")] |
| 148 | + |
| 149 | + return " ".join(files) |
| 150 | + |
| 151 | + |
| 152 | +def get_message_to_aider( |
| 153 | + aider_config: AiderConfig, |
| 154 | + target_edit_files_cmd_args: str, |
| 155 | + repo_path: str, |
| 156 | + ds: Dict[str, Any], |
| 157 | +) -> str: |
| 158 | + """Get the message to Aider.""" |
| 159 | + # support context for aider |
| 160 | + if aider_config.use_user_prompt: |
| 161 | + assert ( |
| 162 | + aider_config.user_prompt != "" |
| 163 | + ), "You choose to use custom user prompt, but it is empty" |
| 164 | + prompt = f"{PROMPT_HEADER} " + aider_config.user_prompt |
| 165 | + else: |
| 166 | + prompt = f"{PROMPT_HEADER} " + get_prompt(target_edit_files_cmd_args) |
| 167 | + |
| 168 | + if aider_config.use_unit_tests_info and ds["test"]["test_dir"]: |
| 169 | + unit_tests_info = ( |
| 170 | + f"\n{UNIT_TESTS_INFO_HEADER} " |
| 171 | + + get_dir_info( |
| 172 | + dir_path=Path(os.path.join(repo_path, ds["test"]["test_dir"])), |
| 173 | + prefix="", |
| 174 | + include_stubs=True, |
| 175 | + )[: aider_config.max_unit_tests_info_length] |
| 176 | + ) |
| 177 | + else: |
| 178 | + unit_tests_info = "" |
| 179 | + |
| 180 | + # TODO: assuming we have specification, which we currently do not have |
| 181 | + if aider_config.use_reference_info and ds["specification"]: |
| 182 | + reference = ( |
| 183 | + f"\n{REFERENCE_HEADER} " |
| 184 | + + get_reference(ds["specification"])[ |
| 185 | + : aider_config.max_reference_info_length |
| 186 | + ] |
| 187 | + ) |
| 188 | + else: |
| 189 | + reference = "" |
| 190 | + |
| 191 | + if aider_config.use_repo_info: |
| 192 | + repo_info = ( |
| 193 | + f"\n{REPO_INFO_HEADER} " |
| 194 | + + get_dir_info( |
| 195 | + dir_path=Path(repo_path), prefix="", max_depth=2, include_stubs=False |
| 196 | + )[: aider_config.max_repo_info_length] |
| 197 | + ) |
| 198 | + else: |
| 199 | + repo_info = "" |
| 200 | + |
| 201 | + message_to_aider = prompt + reference + repo_info + unit_tests_info |
| 202 | + |
| 203 | + return message_to_aider |
| 204 | + |
| 205 | + |
| 206 | +def get_reference(specification_pdf_path: str) -> str: |
| 207 | + """Get the reference for a given specification PDF path.""" |
| 208 | + # TODO: after pdf_to_text is available, use it to extract the text from the PDF |
| 209 | + return f"/pdf {specification_pdf_path}" |
0 commit comments