Skip to content

Commit 4a33937

Browse files
committed
fix type errors
1 parent 38205e6 commit 4a33937

File tree

5 files changed

+63
-24
lines changed

5 files changed

+63
-24
lines changed

commit0/cli.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def setup(
118118
) -> None:
119119
"""Commit0 clone a repo split."""
120120
check_commit0_path()
121-
if "commit0" in dataset_name.split('/')[-1].lower():
121+
if "commit0" in dataset_name.split("/")[-1].lower():
122122
check_valid(repo_split, SPLIT)
123123

124124
base_dir = str(Path(base_dir).resolve())
@@ -169,7 +169,7 @@ def build(
169169
check_commit0_path()
170170

171171
commit0_config = read_commit0_config_file(commit0_config_file)
172-
if "commit0" in commit0_config["dataset_name"].split('/')[-1].lower():
172+
if "commit0" in commit0_config["dataset_name"].split("/")[-1].lower():
173173
check_valid(commit0_config["repo_split"], SPLIT)
174174

175175
typer.echo(
@@ -251,13 +251,13 @@ def test(
251251
commit0_config = read_commit0_config_file(commit0_config_file)
252252
if repo_or_repo_path.endswith("/"):
253253
repo_or_repo_path = repo_or_repo_path[:-1]
254-
if "commit0" in commit0_config["dataset_name"].split('/')[-1].lower():
254+
if "commit0" in commit0_config["dataset_name"].split("/")[-1].lower():
255255
check_valid(repo_or_repo_path.split("/")[-1], SPLIT)
256256

257257
if reference:
258258
branch = "reference"
259259
else:
260-
if "humaneval" not in commit0_config["dataset_name"].split('/')[-1].lower():
260+
if "humaneval" not in commit0_config["dataset_name"].split("/")[-1].lower():
261261
if branch is None and not reference:
262262
git_path = os.path.join(
263263
commit0_config["base_dir"], repo_or_repo_path.split("/")[-1]
@@ -321,7 +321,7 @@ def evaluate(
321321
branch = "reference"
322322

323323
commit0_config = read_commit0_config_file(commit0_config_file)
324-
if "commit0" in commit0_config["dataset_name"].split('/')[-1].lower():
324+
if "commit0" in commit0_config["dataset_name"].split("/")[-1].lower():
325325
check_valid(commit0_config["repo_split"], SPLIT)
326326

327327
typer.echo(f"Evaluating repository split: {commit0_config['repo_split']}")
@@ -397,7 +397,7 @@ def save(
397397
"""Save Commit0 split you choose in Setup Stage to GitHub."""
398398
check_commit0_path()
399399
commit0_config = read_commit0_config_file(commit0_config_file)
400-
if "commit0" in commit0_config["dataset_name"].split('/')[-1].lower():
400+
if "commit0" in commit0_config["dataset_name"].split("/")[-1].lower():
401401
check_valid(commit0_config["repo_split"], SPLIT)
402402

403403
typer.echo(f"Saving repository split: {commit0_config['repo_split']}")

commit0/harness/build.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import docker
44
from datasets import load_dataset
5-
from typing import Iterator
5+
from typing import Iterator, Union
66

77
from commit0.harness.constants import RepoInstance, SimpleInstance, SPLIT
88
from commit0.harness.docker_build import build_repo_images
@@ -21,7 +21,9 @@ def main(
2121
num_workers: int,
2222
verbose: int,
2323
) -> None:
24-
dataset: Iterator[Union[RepoInstance, SimpleInstance]] = load_dataset(dataset_name, split=dataset_split) # type: ignore
24+
dataset: Iterator[Union[RepoInstance, SimpleInstance]] = load_dataset(
25+
dataset_name, split=dataset_split
26+
) # type: ignore
2527
specs = []
2628
if "swe" in dataset_name.lower():
2729
dataset_type = "swebench"

commit0/harness/constants.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from enum import Enum
22
from pathlib import Path
3-
from typing import Dict, TypedDict
3+
from typing import Dict, ItemsView
4+
from pydantic import BaseModel
45

56

6-
class RepoInstance(TypedDict):
7+
class RepoInstance(BaseModel):
78
instance_id: str
89
repo: str
910
base_commit: str
@@ -12,19 +13,34 @@ class RepoInstance(TypedDict):
1213
test: Dict[str, str]
1314
src_dir: str
1415

16+
def __getitem__(self, item: str):
17+
return getattr(self, item)
1518

16-
class SimpleInstance(TypedDict):
19+
20+
class SimpleInstance(BaseModel):
1721
instance_id: str
1822
prompt: str
1923
canonical_solution: str
2024
test: str
2125
entry_point: str
2226

27+
def __getitem__(self, item: str):
28+
return getattr(self, item)
29+
2330

24-
class Files(TypedDict):
31+
class Files(BaseModel):
2532
eval_script: Dict[str, Path]
2633
patch: Dict[str, Path]
2734

35+
def __getitem__(self, item: str):
36+
return getattr(self, item)
37+
38+
def items(self) -> ItemsView[str, object]:
39+
"""Using self.dict() to obtain the underlying data as a dictionary,
40+
which is then iterated to yield key-value pairs.
41+
"""
42+
return self.dict().items()
43+
2844

2945
BASE_BRANCH = "commit0"
3046

commit0/harness/run_pytest_ids.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from datasets import load_dataset
77
from pathlib import Path
88

9-
from typing import Iterator
9+
from typing import Iterator, Union
1010
from commit0.harness.constants import (
1111
EVAL_BACKENDS,
1212
Files,
@@ -48,10 +48,13 @@ def main(
4848
Tests are run either locally through docker
4949
or remotely through Modal.
5050
"""
51-
dataset: Iterator[Union[RepoInstance, SimpleInstance]] = load_dataset(dataset_name, split=dataset_split) # type: ignore
51+
dataset: Iterator[Union[RepoInstance, SimpleInstance]] = load_dataset(
52+
dataset_name, split=dataset_split
53+
) # type: ignore
5254
spec = None
5355
example = None
5456
repo_name = None
57+
dataset_type = None
5558
for example in dataset:
5659
if repo_or_repo_dir.endswith("/"):
5760
repo_or_repo_dir = repo_or_repo_dir[:-1]
@@ -64,12 +67,15 @@ def main(
6467
else:
6568
repo_name = example["repo"].split("/")[-1]
6669
dataset_type = "commit0"
67-
if repo_name in os.path.basename(repo_or_repo_dir) or repo_or_repo_dir.endswith(repo_name):
70+
if repo_name in os.path.basename(repo_or_repo_dir) or repo_or_repo_dir.endswith(
71+
repo_name
72+
):
6873
spec = make_spec(example, dataset_type)
6974
break
7075
assert spec is not None, "No spec available"
7176
assert example is not None, "No example available"
7277
assert repo_name is not None, "No repo available"
78+
assert dataset_type is not None, "No dataset_type available"
7379

7480
hashed_test_ids = get_hash_string(test_ids)
7581
# set up logging
@@ -78,13 +84,17 @@ def main(
7884
log_file = log_dir / "run_pytest.log"
7985
logger = setup_logger(repo_name, log_file, verbose=verbose)
8086

81-
if dataset_type != "simple": # if dataset_type is not simple, load git repo
87+
if not isinstance(
88+
example, SimpleInstance
89+
): # if dataset_type is not simple, load git repo
8290
try:
8391
local_repo = git.Repo(repo_or_repo_dir)
8492
logger.info(f"Loaded a git repo from {repo_or_repo_dir}")
8593
except (git.exc.NoSuchPathError, git.exc.InvalidGitRepositoryError): # type: ignore
8694
repo_dir = os.path.join(base_dir, repo_name)
87-
logger.error(f"{repo_or_repo_dir} is not a git dir, trying {repo_dir} again")
95+
logger.error(
96+
f"{repo_or_repo_dir} is not a git dir, trying {repo_dir} again"
97+
)
8898
try:
8999
local_repo = git.Repo(repo_dir)
90100
logger.info(f"Retried succeeded. Loaded a git repo from {repo_dir}")
@@ -117,10 +127,18 @@ def main(
117127
if found_remote_branch:
118128
break # Stop checking other remotes if branch is found
119129
if not found_remote_branch:
120-
raise Exception(f"Branch {branch} does not exist locally or remotely.")
121-
if dataset_type == "simple":
130+
raise Exception(
131+
f"Branch {branch} does not exist locally or remotely."
132+
)
133+
if isinstance(example, SimpleInstance):
122134
if branch == "reference":
123-
patch = example["prompt"] + "\n\n" + example["canonical_solution"] + "\n\n" + example["test"]
135+
patch = (
136+
example["prompt"]
137+
+ "\n\n"
138+
+ example["canonical_solution"]
139+
+ "\n\n"
140+
+ example["test"]
141+
)
124142
else:
125143
solution = open(test_ids).read()
126144
pattern = r"```python\n(.*?)```"
@@ -147,10 +165,12 @@ def main(
147165
patch_file = Path(log_dir / "patch.diff")
148166
patch_file.write_text(patch, encoding="utf-8", errors="ignore")
149167

150-
if dataset_type != "simple":
168+
if not isinstance(example, SimpleInstance):
151169
# make eval file
152170
if coverage:
153-
coverage_text = f" --cov={example['src_dir']} --cov-branch --cov-report json"
171+
coverage_text = (
172+
f" --cov={example['src_dir']} --cov-branch --cov-report json"
173+
)
154174
else:
155175
coverage_text = ""
156176
eval_script = spec.eval_script.format(test_ids=test_ids, coverage=coverage_text)

commit0/harness/spec.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def make_repo_script_list(self) -> list[str]:
183183
"""
184184
setup_commands = [
185185
f"mkdir {self.repo_directory} && cd {self.repo_directory}",
186-
f"uv venv --python 3.12",
186+
"uv venv --python 3.12",
187187
"source .venv/bin/activate",
188188
"which python",
189189
]
@@ -303,7 +303,8 @@ def make_eval_script_list(self) -> list[str]:
303303

304304

305305
def get_specs_from_dataset(
306-
dataset: Union[list[Union[RepoInstance, SimpleInstance]], list[Spec]], dataset_type: str
306+
dataset: Union[list[Union[RepoInstance, SimpleInstance]], list[Spec]],
307+
dataset_type: str,
307308
) -> list[Spec]:
308309
"""Idempotent function that converts a list of RepoInstance objects to a list of Spec objects."""
309310
if isinstance(dataset[0], Spec):

0 commit comments

Comments
 (0)