
Commit 8d710f7

precommit?
1 parent 1092765 commit 8d710f7

File tree

1 file changed: +74 −45 lines


docs/render_submissions.py

Lines changed: 74 additions & 45 deletions
@@ -18,6 +18,7 @@
 from commit0.cli import write_commit0_config_file
 
 import logging
+from typing import Any, NoReturn
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -27,9 +28,13 @@
 analysis_files_path = "/share/rush/commit0_analysis_temp"
 
 
-def get_pytest_info(path_to_logs, repo_name, branch_name):
+def get_pytest_info(
+    path_to_logs: str, repo_name: str, branch_name: str
+) -> dict[str, dict[str, Any]]:
     pytest_info = {}
     for pytest_hash in os.listdir(path_to_logs):
+        if not os.path.exists(os.path.join(path_to_logs, pytest_hash, "eval.sh")):
+            continue
         eval_script = open(os.path.join(path_to_logs, pytest_hash, "eval.sh")).read()
         testname = re.search(r"([\S]+) > test_output", eval_script).group(1)
         patch_diff = open(os.path.join(path_to_logs, pytest_hash, "patch.diff")).read()
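The added existence check makes the renderer tolerant of incomplete runs: a log directory without an `eval.sh` is now skipped instead of raising when the file is opened. A minimal standalone sketch of the same guard (directory layout and function name are hypothetical, not part of the commit):

```python
import os


def list_complete_runs(path_to_logs: str) -> list[str]:
    """Return only the per-run hash directories that actually contain an eval.sh."""
    complete = []
    for pytest_hash in os.listdir(path_to_logs):
        eval_sh = os.path.join(path_to_logs, pytest_hash, "eval.sh")
        if not os.path.exists(eval_sh):
            continue  # half-finished or crashed evaluation: skip rather than raise
        complete.append(pytest_hash)
    return complete
```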
@@ -85,19 +90,19 @@ def get_pytest_info(path_to_logs, repo_name, branch_name):
                 "failure_string": failure_string,
                 "duration": duration,
             }
-    return pytest_info
+    return pytest_info if len(pytest_info) else "Could not evaluate"
 
 
-def get_coverage_info(path_to_logs, repo_name, branch_name):
+def get_coverage_info(path_to_logs: str, repo_name: str, branch_name: str) -> Any:
     raise NotImplementedError
 
 
 def get_blank_repo_metrics(
-    blank_source_code_folder,
-    spec_filename,
+    blank_source_code_folder: str,
+    spec_filename: str,
     tokenizer,
     code_file_filter=lambda filename: filename,
-):
+) -> dict[str, Any]:
     blank_repo_metrics = {
         "functions_to_edit": [],
     }
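With the new fallback, `get_pytest_info` returns either a dict of per-test results or the sentinel string `"Could not evaluate"`, which is why later code branches on `isinstance(repo_pytest_results, str)`. A hedged usage sketch of that contract (the `summarize` helper is illustrative only):

```python
from typing import Any, Union


def summarize(results: Union[dict[str, dict[str, Any]], str]) -> str:
    # The string case is the sentinel returned when no pytest logs could be read.
    if isinstance(results, str):
        return f"evaluation failed: {results}"
    return f"{len(results)} test group(s) evaluated"
```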
@@ -165,7 +170,7 @@ def get_blank_repo_metrics(
 
 
 leaderboard_header = """\n\n## Leaderboard ({split})
-| Name | Repos Resolved (/{num_repos}) | Tests Passed (Total: {total_num_tests}) | Test Duration (s) | Date | Analysis | Github |
+| Name | Repos Resolved (/{num_repos}) | Avg. pass rate | Test Duration (s) | Date | Analysis | Github |
 |------|:-------------------------:|:--------------------:|:--------------------:|:----------:|----|----| """
 
 submission_table_header = """# Submission Name: **{display_name}** (split: {split})
@@ -179,7 +184,7 @@ def get_blank_repo_metrics(
 """
 
 
-def render_mds(overwrite_previous, subfolder="docs"):
+def render_mds(overwrite_previous: bool, subfolder: str = "docs") -> NoReturn:
     leaderboard = {}
 
     split_to_total_tests = {
@@ -193,11 +198,16 @@ def render_mds(overwrite_previous, subfolder="docs"):
         # repo_tests = subprocess.run(['commit0', 'get-tests', repo_name], capture_output=True, text=True).stdout.strip()
         # total_num_tests += len(repo_tests.splitlines())
         leaderboard[split] = []
-        leaderboard[split].append((split_to_total_tests[split]+1, leaderboard_header.format(
-            split=split,
-            num_repos=num_repos,
-            total_num_tests=split_to_total_tests[split],
-        )))
+        leaderboard[split].append(
+            (
+                split_to_total_tests[split] + 1,
+                leaderboard_header.format(
+                    split=split,
+                    num_repos=num_repos,
+                    total_num_tests=split_to_total_tests[split],
+                ),
+            )
+        )
 
     for org_path in tqdm.tqdm(glob.glob(os.path.join(analysis_files_path, "*"))):
         org_name = os.path.basename(org_path)
@@ -241,7 +251,7 @@ def render_mds(overwrite_previous, subfolder="docs"):
                 subfolder, f"analysis_{org_name}_{branch_name}_{repo_name}.md"
             )
             if isinstance(repo_pytest_results, str):
-                submission_repo_page = f"# **{display_name}**: {repo_name}\n\n## Failed to clone\n\n{repo_pytest_results}"
+                submission_repo_page = f"# **{display_name}**: {repo_name}\n\n## Failed\n\n{repo_pytest_results}"
                 org_branch_repo_filepath = os.path.join(
                     subfolder, f"analysis_{org_name}_{branch_name}_{repo_name}.md"
                 )
@@ -253,7 +263,7 @@ def render_mds(overwrite_previous, subfolder="docs"):
                 submission_page = submission_table_header.format(
                     display_name=display_name, split=split
                 ) + (
-                    f"\n| {repo_name} | No; Failed to clone. | - | - | "
+                    f"\n| {repo_name} | No; {repo_pytest_results} | - | - | "
                     f"[Analysis](/{f'analysis_{org_name}_{branch_name}_{repo_name}'}) | "
                     f"[Github]({github_hyperlink}) |"
                 )
@@ -274,16 +284,23 @@ def render_mds(overwrite_previous, subfolder="docs"):
                     )
                     pytest_details = "Pytest failed"
                     duration = "Failed."
-                    evaluate_numbers.append(0.)
-                    if split == "all" and repo_name in SPLIT['lite']:
-                        lite_evaluate_numbers.append(0.)
+                    evaluate_numbers.append(0.0)
+                    if split == "all" and repo_name in SPLIT["lite"]:
+                        lite_evaluate_numbers.append(0.0)
                 else:
                     resolved = False
                     if "passed" in pytest_info["summary"]:
                         if "skipped" in pytest_info["summary"]:
-                            resolved = pytest_info["summary"]["passed"] + pytest_info["summary"]["skipped"] == pytest_info["summary"]["total"]
+                            resolved = (
+                                pytest_info["summary"]["passed"]
+                                + pytest_info["summary"]["skipped"]
+                                == pytest_info["summary"]["total"]
+                            )
                         else:
-                            resolved = pytest_info["summary"]["passed"] == pytest_info["summary"]["total"]
+                            resolved = (
+                                pytest_info["summary"]["passed"]
+                                == pytest_info["summary"]["total"]
+                            )
                     if write_submission:
                         submission_repo_page += pytest_summary_table_header.format(
                             pytest_group=pytest_group
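The reflowed expressions keep the original rule: a repo counts as resolved only when pytest reported passes and the passed (plus skipped, if present) count covers every collected test. A small sketch of that rule over a pytest-style summary dict (illustrative, not the project's helper):

```python
def is_resolved(summary: dict[str, int]) -> bool:
    # Mirrors the diff: nothing is resolved unless pytest reported a "passed" count.
    if "passed" not in summary:
        return False
    passed_or_skipped = summary["passed"] + summary.get("skipped", 0)
    return passed_or_skipped == summary["total"]


print(is_resolved({"passed": 8, "skipped": 2, "total": 10}))  # True
print(is_resolved({"passed": 8, "failed": 2, "total": 10}))   # False
```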
@@ -307,11 +324,15 @@ def render_mds(overwrite_previous, subfolder="docs"):
                         )
                     # cum_tests_passed += pytest_info["summary"]["passed"]
                     num_tests = len(get_tests(repo_name, verbose=0))
-                    evaluate_numbers.append(pytest_info["summary"]["passed"] / num_tests)
+                    evaluate_numbers.append(
+                        pytest_info["summary"]["passed"] / num_tests
+                    )
                     total_duration += pytest_info["duration"]
                     repos_resolved += int(resolved)
-                    if split == "all" and repo_name in SPLIT['lite']:
-                        lite_evaluate_numbers.append(pytest_info["summary"]["passed"] / num_tests)
+                    if split == "all" and repo_name in SPLIT["lite"]:
+                        lite_evaluate_numbers.append(
+                            pytest_info["summary"]["passed"] / num_tests
+                        )
                         # lite_cum_tests_passed += pytest_info["summary"]["passed"]
                         lite_total_duration += pytest_info["duration"]
                         lite_repos_resolved += int(resolved)
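Each repo now contributes a fraction, `passed / num_tests`, to `evaluate_numbers`, and the leaderboard averages those fractions; this is what feeds the new "Avg. pass rate" column. A worked example with made-up numbers:

```python
# Hypothetical submission: three repos with very different test-suite sizes.
passed = [90, 5, 300]
totals = [100, 10, 400]

pass_rates = [p / t for p, t in zip(passed, totals)]  # [0.9, 0.5, 0.75]
avg_pass_rate = sum(pass_rates) / len(pass_rates)     # 0.7166...

print(f"{avg_pass_rate * 100:.2f}%")  # 71.67% -- each repo weighted equally
```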
@@ -341,26 +362,34 @@ def render_mds(overwrite_previous, subfolder="docs"):
             analysis_link = f"[Analysis](/{f'analysis_{org_name}_{branch_name}'})"
             github_link = f"[Github]({project_page_link})"
             avg_pass_rate = sum(evaluate_numbers) / len(evaluate_numbers)
-            leaderboard[split].append((avg_pass_rate * 100,
-                f"\n|{display_name}|"
-                f"{repos_resolved}|"
-                f"{avg_pass_rate*100:.2f}%|"
-                f"{total_duration:.2f}|"
-                f"{submission_date}|"
-                f"{analysis_link}|"
-                f"{github_link}|"
-            ))
-            if ((split == "all") and ("Reference (Gold)" not in display_name)):
-                avg_lite_pass_rate = sum(lite_evaluate_numbers) / len(lite_evaluate_numbers)
-                leaderboard["lite"].append((avg_lite_pass_rate * 100,
-                    f"\n|{display_name} (subset of `all`)|"
-                    f"{lite_repos_resolved}|"
-                    f"{avg_lite_pass_rate*100:.2f}%|"
-                    f"{lite_total_duration:.2f}|"
+            leaderboard[split].append(
+                (
+                    avg_pass_rate * 100,
+                    f"\n|{display_name}|"
+                    f"{repos_resolved}|"
+                    f"{avg_pass_rate*100:.2f}%|"
+                    f"{total_duration:.2f}|"
                     f"{submission_date}|"
                     f"{analysis_link}|"
-                    f"{github_link}|"
-                ))
+                    f"{github_link}|",
+                )
+            )
+            if (split == "all") and ("Reference (Gold)" not in display_name):
+                avg_lite_pass_rate = sum(lite_evaluate_numbers) / len(
+                    lite_evaluate_numbers
+                )
+                leaderboard["lite"].append(
+                    (
+                        avg_lite_pass_rate * 100,
+                        f"\n|{display_name} (subset of `all`)|"
+                        f"{lite_repos_resolved}|"
+                        f"{avg_lite_pass_rate*100:.2f}%|"
+                        f"{lite_total_duration:.2f}|"
+                        f"{submission_date}|"
+                        f"{analysis_link}|"
+                        f"{github_link}|",
+                    )
+                )
 
     leaderboard_filepath = os.path.join(subfolder, "analysis.md")
     for split in ["lite", "all"]:
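Each leaderboard entry is a `(sort_key, markdown_row)` tuple: submissions are keyed by pass rate × 100 and the header row by `total_num_tests + 1`, so a descending sort presumably keeps the header on top. A hedged sketch of how such tuples could be sorted and joined (keys and row contents are invented):

```python
rows = [
    (45.10, "\n|team-b|4|45.10%|987.65|2024-10-02|[Analysis](/b)|[Github](https://example.com/b)|"),
    (1000.0, "\n| Name | Repos Resolved (/16) | Avg. pass rate | Test Duration (s) | Date | Analysis | Github |"),
    (71.67, "\n|team-a|9|71.67%|1234.56|2024-10-01|[Analysis](/a)|[Github](https://example.com/a)|"),
]

# The header key (standing in for total_num_tests + 1) exceeds any percentage,
# so sorting in descending order keeps the header row first.
table = "".join(row for _, row in sorted(rows, key=lambda kv: kv[0], reverse=True))
print(table)
```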
@@ -371,7 +400,7 @@ def render_mds(overwrite_previous, subfolder="docs"):
         wf.write(lite_leaderboard_string + "\n\n" + all_leaderboard_string)
 
 
-def get_args():
+def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--do_setup", action="store_true", help="Run commit0 setup with specified split"
@@ -400,14 +429,14 @@ def get_args():
     parser.add_argument(
         "--overwrite_previous_eval",
        action="store_true",
-        help="Overwrite cached pytest info"
+        help="Overwrite cached pytest info",
         # TODO add finer granularity so can specify which ones to overwrite
     )
 
     return parser.parse_args()
 
 
-def main(args):
+def main(args: argparse.Namespace) -> NoReturn:
     global analysis_files_path
 
     commit0_dataset_name = "wentingzhao/commit0_combined"
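Annotating `get_args` with `-> argparse.Namespace` is what makes the new `main(args: argparse.Namespace)` signature line up. A stripped-down sketch of the pattern using only flags visible in this diff (everything else is omitted):

```python
import argparse


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--do_setup", action="store_true", help="Run commit0 setup with specified split"
    )
    parser.add_argument(
        "--overwrite_previous_eval",
        action="store_true",
        help="Overwrite cached pytest info",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    print(args.do_setup, args.overwrite_previous_eval)
```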
@@ -565,7 +594,7 @@ def main(args):
         )
         # run pytests
         os.system(
-            f"commit0 evaluate --branch {branch_name} --timeout 1800"
+            f"commit0 evaluate --branch {branch_name} --timeout 1800 "
             f"--commit0-config-file {commit0_dot_file_path}"
         )
         for example in dataset:
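The one-character change matters because adjacent f-string literals concatenate with no separator: without the trailing space the two fragments fuse into `--timeout 1800--commit0-config-file ...`, producing a malformed `commit0 evaluate` command. A quick illustration (the config path is a placeholder):

```python
branch_name = "my-branch"
commit0_dot_file_path = ".commit0.yaml"  # placeholder value

broken = (
    f"commit0 evaluate --branch {branch_name} --timeout 1800"
    f"--commit0-config-file {commit0_dot_file_path}"
)
fixed = (
    f"commit0 evaluate --branch {branch_name} --timeout 1800 "
    f"--commit0-config-file {commit0_dot_file_path}"
)

print(broken)  # ... --timeout 1800--commit0-config-file .commit0.yaml  (flags fused)
print(fixed)   # ... --timeout 1800 --commit0-config-file .commit0.yaml
```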
