Skip to content

Commit 49c9333

Browse files
Merge remote-tracking branch 'origin/main' into modal-clean
2 parents 9e08b2c + e9cfc06 commit 49c9333

File tree

12 files changed

+359
-20
lines changed

12 files changed

+359
-20
lines changed

.github/workflows/system.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,10 @@ jobs:
2626
run: uv run commit0 get-tests simpy
2727
- name: Test
2828
run: uv run commit0 test-reference simpy tests/test_event.py::test_succeed
29+
- name: Evaluate
30+
run: uv run commit0 evaluate-reference simpy
31+
- name: Save
32+
env:
33+
GITHUB_TOKEN: ${{ secrets.MY_GITHUB_TOKEN }}
34+
run: |
35+
uv run commit0 save simpy test-save-commit0

commit0/__main__.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import commit0.harness.get_pytest_ids
33
import commit0.harness.build
44
import commit0.harness.setup
5+
import commit0.harness.evaluate
6+
import commit0.harness.save
57
import copy
68
import sys
79
import os
@@ -19,7 +21,7 @@ def main() -> None:
1921
)
2022
# type check config values
2123
cs = ConfigStore.instance()
22-
cs.store(name="user", node=Commit0Config)
24+
cs.store(name="user", group="Commit0Config", node=Commit0Config)
2325
# have hydra to ignore all command-line arguments
2426
sys_argv = copy.deepcopy(sys.argv)
2527
sys.argv = [sys.argv[0]]
@@ -28,8 +30,8 @@ def main() -> None:
2830
# after hydra gets all configs, put command-line arguments back
2931
sys.argv = sys_argv
3032
# repo_split: split from command line has a higher priority than split in hydra
31-
if command in ["clone", "build"]:
32-
if len(sys.argv) == 3:
33+
if command in ["clone", "build", "evaluate", "evaluate-reference", "save"]:
34+
if len(sys.argv) >= 3:
3335
if sys.argv[2] not in SPLIT:
3436
raise ValueError(
3537
f"repo split must be from {', '.join(SPLIT.keys())}, but you provided {sys.argv[2]}"
@@ -43,6 +45,7 @@ def main() -> None:
4345
config.dataset_split,
4446
config.repo_split,
4547
config.base_dir,
48+
config.branch,
4649
)
4750
elif command == "build":
4851
commit0.harness.build.main(
@@ -53,7 +56,7 @@ def main() -> None:
5356
)
5457
elif command == "get-tests":
5558
repo = sys.argv[2]
56-
commit0.harness.get_pytest_ids.main(repo)
59+
commit0.harness.get_pytest_ids.main(repo, stdout=True)
5760
elif command == "test" or command == "test-reference":
5861
repo = sys.argv[2]
5962
test_ids = sys.argv[3]
@@ -68,6 +71,31 @@ def main() -> None:
6871
test_ids,
6972
config.backend,
7073
config.timeout,
74+
stdout=True,
75+
)
76+
elif command == "evaluate" or command == "evaluate-reference":
77+
if command == "evaluate-reference":
78+
config.branch = "reference"
79+
commit0.harness.evaluate.main(
80+
config.dataset_name,
81+
config.dataset_split,
82+
config.repo_split,
83+
config.base_dir,
84+
config.branch,
85+
config.backend,
86+
config.timeout,
87+
config.num_workers,
88+
)
89+
elif command == "save":
90+
organization = sys.argv[3]
91+
commit0.harness.save.main(
92+
config.dataset_name,
93+
config.dataset_split,
94+
config.repo_split,
95+
config.base_dir,
96+
organization,
97+
config.branch,
98+
config.github_token,
7199
)
72100

73101

commit0/configs/base.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,6 @@ num_workers: 8
1616
backend: local
1717
branch: ai
1818
timeout: 1_800
19+
20+
# save related
21+
github_token: null

commit0/configs/config_class.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
from typing import Optional
23

34

45
@dataclass
@@ -21,3 +22,6 @@ class Commit0Config:
2122
branch: str
2223
# timeout for running pytest
2324
timeout: int
25+
26+
# save related
27+
github_token: Optional[str]

commit0/harness/constants.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,16 @@ class RepoInstance(TypedDict):
2626
EVAL_BACKENDS = ["local", "modal"]
2727

2828
# available commands
29-
COMMANDS = ["clone", "build", "test", "test-reference", "get-tests"]
29+
COMMANDS = [
30+
"clone",
31+
"build",
32+
"test",
33+
"test-reference",
34+
"get-tests",
35+
"evaluate",
36+
"evaluate-reference",
37+
"save",
38+
]
3039
# repo splits
3140
SPLIT_MINITORCH = ["minitorch"]
3241
SPLIT_SIMPY = ["simpy"]
@@ -80,7 +89,8 @@ class RepoInstance(TypedDict):
8089
"mimesis",
8190
"babel",
8291
"dnspython",
83-
"portalocker," "cookiecutter",
92+
"portalocker",
93+
"cookiecutter",
8494
"pyjwt",
8595
"python-rsa",
8696
"more-itertools",

commit0/harness/evaluate.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import logging
2+
import os
3+
import traceback
4+
from collections import Counter
5+
6+
from concurrent.futures import ThreadPoolExecutor, as_completed
7+
from datasets import load_dataset
8+
from tqdm import tqdm
9+
from typing import Iterator
10+
11+
from commit0.harness.run_pytest_ids import main as run_tests
12+
from commit0.harness.get_pytest_ids import main as get_tests
13+
from commit0.harness.constants import RepoInstance, SPLIT
14+
15+
logging.basicConfig(
16+
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
17+
)
18+
logger = logging.getLogger(__name__)
19+
20+
21+
def main(
    dataset_name: str,
    dataset_split: str,
    repo_split: str,
    base_dir: str,
    branch: str,
    backend: str,
    timeout: int,
    num_workers: int,
) -> None:
    """Evaluate a branch of every repo in *repo_split* and print a pass-rate report.

    Runs the pytest suite for each selected repo concurrently (via
    ``run_tests``), then parses each run's ``report.json`` to compute the
    per-repo runtime and the fraction of test ids that passed (``xfail``
    counts as passed).

    Args:
        dataset_name: HuggingFace dataset identifier holding the repo specs.
        dataset_split: Split of the dataset to load.
        repo_split: Key into ``SPLIT`` selecting repos, or ``"all"``.
        base_dir: Directory where the repos were cloned.
        branch: Git branch whose code is evaluated.
        backend: Execution backend (``local`` or ``modal``).
        timeout: Per-repo pytest timeout in seconds.
        num_workers: Number of concurrent evaluation workers.
    """
    dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split)  # type: ignore
    pairs = []
    for example in dataset:
        repo_name = example["repo"].split("/")[-1]
        if repo_split != "all" and repo_name not in SPLIT[repo_split]:
            continue
        pairs.append((repo_name, example["test"]["test_dir"]))

    log_dirs = []
    # total=len(pairs): the bar must match the number of submitted jobs.
    # (Using len(SPLIT[repo_split]) diverges from the job count when
    # repo_split == "all" or a repo is absent from the dataset.)
    with tqdm(total=len(pairs), smoothing=0, desc="Evaluating repos") as pbar:
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            # One future per (repo, test_dir) pair.
            futures = {
                executor.submit(
                    run_tests,
                    dataset_name,
                    dataset_split,
                    base_dir,
                    repo,
                    branch,
                    test_dir,
                    backend,
                    timeout,
                    stdout=False,
                ): None
                for repo, test_dir in pairs
            }
            # Collect log directories as runs finish; a failed run is
            # reported but does not abort the remaining evaluations.
            for future in as_completed(futures):
                pbar.update(1)
                try:
                    # run_tests returns the log dir containing report.json.
                    log_dirs.append(future.result())
                except Exception:
                    traceback.print_exc()
                    continue

    # Aggregate per-repo numbers from each run's report.json.
    out = []
    for name in tqdm(log_dirs):
        report_file = os.path.join(name, "report.json")
        # Log dirs look like <root>/<pytest>/<repo>/<branch>/<hash>;
        # index 2 is the repo name — TODO confirm against RUN_PYTEST_LOG_DIR.
        name = name.split("/")[2]
        if not os.path.exists(report_file):
            # The run produced no report (e.g. it crashed): count it as
            # 0 passed. num_tests must be present — the summary below
            # reads x['num_tests'] and would otherwise raise KeyError.
            out.append(
                {
                    "name": name,
                    "sum": 0,
                    "passed": 0,
                    "num_passed": 0,
                    "num_tests": 0,
                }
            )
            continue
        report = load_dataset("json", data_files=report_file, split="train")  # type: ignore
        test_ids = get_tests(name, stdout=False)
        tests = {x["nodeid"]: x["call"] for x in report["tests"][0]}  # type: ignore
        status = []
        runtimes = []
        for test_id in test_ids:
            if test_id in tests and tests[test_id] is not None:
                status.append(tests[test_id]["outcome"])
                runtimes.append(tests[test_id]["duration"])
            else:
                # A test id that never ran is treated as a failure with
                # zero runtime.
                status.append("failed")
                runtimes.append(0)
        counts = Counter(status)  # Counter defaults missing outcomes to 0
        total = sum(runtimes)
        num_tests = sum(counts.values())
        # xfail (expected failure) counts toward the pass rate.
        num_passed = counts["passed"] + counts["xfail"]
        out.append(
            {
                "name": name,
                "sum": total,
                # Guard: a repo with no test ids must not divide by zero.
                "passed": num_passed / num_tests if num_tests else 0,
                "num_passed": num_passed,
                "num_tests": num_tests,
            }
        )
    print("repo,runtime,num_passed/num_tests")
    out = sorted(out, key=lambda x: x["sum"], reverse=True)
    for x in out:
        print(f"{x['name']},{x['sum']},{x['num_passed']}/{x['num_tests']}")
    total_runtime = sum([x["sum"] for x in out])
    # Guard against an empty split so the summary does not divide by zero.
    averaged_passed = sum([x["passed"] for x in out]) / len(out) if out else 0.0
    print(f"total runtime: {total_runtime}")
    print(f"average pass rate: {averaged_passed}")
124+
125+
126+
__all__ = []

commit0/harness/get_pytest_ids.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
import tarfile
2+
from typing import List
23

34

4-
def main(repo: str) -> None:
5+
def main(repo: str, stdout: bool) -> List[str]:
    """Return the pytest node ids bundled for *repo*.

    Reads every file inside ``commit0/data/test_ids/<repo>.tar.bz2`` and
    returns the concatenated contents split into one test id per line.

    Args:
        repo: Repository name; normalized to lowercase with ``.`` replaced
            by ``-`` to match the archive naming scheme.
        stdout: When True, also print each extracted file's contents.

    Returns:
        List of pytest node ids, with blank entries removed.
    """
    repo = repo.lower()
    repo = repo.replace(".", "-")
    out = ""
    with tarfile.open(f"commit0/data/test_ids/{repo}.tar.bz2", "r:bz2") as tar:
        for member in tar.getmembers():
            if member.isfile():
                file = tar.extractfile(member)
                if file:
                    content = file.read().decode("utf-8")
                    out += content
                    if stdout:
                        print(content)
    # Drop empty entries: a trailing newline in the archive content would
    # otherwise yield a spurious "" test id that downstream consumers
    # (e.g. evaluate.py) count as a failed test.
    return [test_id for test_id in out.split("\n") if test_id]
1420

1521

1622
__all__ = []

commit0/harness/run_pytest_ids.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,12 @@ class ExecutionBackend(StrEnum):
4141

4242

4343
def run_docker(
44-
spec: Spec, logger: logging.Logger, eval_file: Path, timeout: int, log_dir: Path
44+
spec: Spec,
45+
logger: logging.Logger,
46+
eval_file: Path,
47+
timeout: int,
48+
log_dir: Path,
49+
stdout: bool,
4550
) -> None:
4651
"""Runs the tests in a local docker container.
4752
@@ -76,7 +81,8 @@ def run_docker(
7681
output, "--json-report --json-report-file=report.json"
7782
)
7883
# stdout might be more straightforward
79-
print(test_output)
84+
if stdout:
85+
print(test_output)
8086
test_output_path = log_dir / "test_output.txt"
8187
with open(test_output_path, "w") as f:
8288
f.write(test_output)
@@ -116,7 +122,12 @@ def run_docker(
116122

117123

118124
def run_modal(
119-
spec: Spec, logger: logging.Logger, eval_file: Path, timeout: int, log_dir: Path
125+
spec: Spec,
126+
logger: logging.Logger,
127+
eval_file: Path,
128+
timeout: int,
129+
log_dir: Path,
130+
stdout: bool,
120131
) -> None:
121132
"""Runs the tests in a remote Modal container.
122133
@@ -156,7 +167,32 @@ def run_modal(
156167
# TODO: add timeout
157168
print(output)
158169
print(error)
159-
return
170+
171+
output = []
172+
for line in process.stderr:
173+
output.append(line)
174+
output_s = "".join(line)
175+
logger.info(output_s)
176+
print(output_s)
177+
178+
timed_out = False
179+
test_output = extract_test_output(
180+
output_s, "--json-report --json-report-file=report.json"
181+
)
182+
183+
# stdout might be more straightforward
184+
if stdout:
185+
print(test_output)
186+
test_output_path = log_dir / "test_output.txt"
187+
with open(test_output_path, "w") as f:
188+
f.write(test_output)
189+
if timed_out:
190+
f.write(f"\n\nTimeout error: {timeout} seconds exceeded.")
191+
raise EvaluationError(
192+
spec.repo,
193+
f"Test timed out after {timeout} seconds.",
194+
logger,
195+
)
160196

161197

162198
def main(
@@ -168,7 +204,8 @@ def main(
168204
test_ids: str,
169205
backend: str,
170206
timeout: int,
171-
) -> None:
207+
stdout: bool,
208+
) -> str:
172209
"""Runs the pytests for repos in a dataset.
173210
174211
Tests are run either locally through docker
@@ -186,7 +223,7 @@ def main(
186223

187224
hashed_test_ids = get_hash_string(test_ids)
188225
# set up logging
189-
log_dir = RUN_PYTEST_LOG_DIR / repo / hashed_test_ids
226+
log_dir = RUN_PYTEST_LOG_DIR / repo / branch / hashed_test_ids
190227
log_dir.mkdir(parents=True, exist_ok=True)
191228
log_file = log_dir / "run_pytest.log"
192229
logger = setup_logger(repo, log_file)
@@ -262,9 +299,11 @@ def main(
262299

263300
"""
264301
if ExecutionBackend(backend) == ExecutionBackend.LOCAL:
265-
run_docker(spec, logger, eval_file, timeout, log_dir)
302+
run_docker(spec, logger, eval_file, timeout, log_dir, stdout)
266303
elif ExecutionBackend(backend) == ExecutionBackend.MODAL:
304+
run_modal(spec, logger, eval_file, timeout, log_dir, stdout)
267305
"""
306+
return str(log_dir)
268307

269308

270309
__all__ = []

0 commit comments

Comments
 (0)