Skip to content

Commit c03b187

Browse files
pyright
1 parent 1c23c1b commit c03b187

File tree

2 files changed

+57
-19
lines changed

2 files changed

+57
-19
lines changed

commit0/harness/execution_context.py

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,19 @@
44
and HTTP servers.
55
"""
66

7-
from abc import ABC
7+
from abc import ABC, abstractmethod
88
import docker
99
import logging
1010
import os
1111
import modal
1212
from pathlib import Path
13+
from typing import Optional, Type
14+
from types import TracebackType
1315

1416
from commit0.harness.spec import Spec
17+
from commit0.harness.utils import (
18+
EvaluationError,
19+
)
1520
from commit0.harness.docker_build import (
1621
close_logger,
1722
)
@@ -27,6 +32,7 @@
2732

2833

2934
def read_stream(stream: modal.io_streams.StreamReader) -> str:
35+
"""Read stream"""
3036
strings = []
3137
for line in stream:
3238
strings.append(line)
@@ -53,33 +59,46 @@ def __init__(
5359
self.timeout = timeout
5460
self.log_dir = log_dir
5561

62+
@abstractmethod
5663
def copy_ssh_pubkey_from_remote(self) -> None:
64+
"""Copy"""
5765
raise NotImplementedError
5866

67+
@abstractmethod
5968
def copy_to_remote(self, local_path: Path, remote_path: Path) -> None:
69+
"""Copy"""
6070
raise NotImplementedError
6171

72+
@abstractmethod
6273
def exec_run_with_timeout(self, command: str, timeout: int) -> None:
74+
"""Exec"""
6375
raise NotImplementedError
6476

77+
@abstractmethod
6578
def exec_run(self, command: str) -> None:
79+
"""Exec"""
6680
raise NotImplementedError
6781

82+
@abstractmethod
6883
def copy_from_remote(self, remote_path: Path, local_path: Path) -> None:
84+
"""Copy"""
6985
raise NotImplementedError
7086

87+
@abstractmethod
7188
def delete_file_from_remote(self, remote_path: Path) -> None:
89+
"""Delete"""
7290
raise NotImplementedError
7391

74-
def write_test_output(self, test_output, timed_out):
92+
def write_test_output(self, test_output: str, timed_out: bool) -> None:
93+
"""Write test output"""
7594
test_output_path = self.log_dir / "test_output.txt"
7695
with open(test_output_path, "w") as f:
7796
f.write(test_output)
7897
if timed_out:
79-
f.write(f"\n\nTimeout error: {timeout} seconds exceeded.")
98+
f.write(f"\n\nTimeout error: {self.timeout} seconds exceeded.")
8099
raise EvaluationError(
81100
self.spec.repo,
82-
f"Test timed out after {timeout} seconds.",
101+
f"Test timed out after {self.timeout} seconds.",
83102
self.logger,
84103
)
85104

@@ -95,7 +114,13 @@ def write_test_output(self, test_output, timed_out):
95114
def __enter__(self):
96115
return self
97116

98-
def __exit__(self, exc_type, exc_value, exc_traceback):
117+
@abstractmethod
118+
def __exit__(
119+
self,
120+
exctype: Optional[Type[BaseException]],
121+
excinst: Optional[BaseException],
122+
exctb: Optional[TracebackType],
123+
) -> bool:
99124
raise NotImplementedError
100125

101126

@@ -121,26 +146,36 @@ def __init__(
121146
self.copy_ssh_pubkey_from_remote()
122147
copy_to_container(self.container, eval_file, Path("/eval.sh"))
123148

124-
125149
def copy_ssh_pubkey_from_remote(self) -> None:
150+
"""Copy"""
126151
copy_ssh_pubkey_from_container(self.container)
127152

128153
def copy_to_remote(self, local_file: Path, remote_path: Path) -> None:
154+
"""Copy"""
129155
copy_to_container(self.container, local_file, remote_path)
130156

131157
def exec_run_with_timeout(self, command: str, timeout: int) -> ():
158+
"""Exec"""
132159
return exec_run_with_timeout(self.container, command, timeout)
133160

134161
def exec_run(self, command: str) -> None:
162+
"""Exec"""
135163
return self.container.exec_run(command, demux=True)
136164

137165
def copy_from_remote(self, remote_path: Path, local_path: Path) -> None:
166+
"""Copy"""
138167
copy_from_container(self.container, remote_path, local_path)
139168

140169
def delete_file_from_remote(self, remote_path: Path) -> None:
170+
"""Delete"""
141171
delete_file_from_container(self.container, str(remote_path))
142172

143-
def __exit__(self, exc_type, exc_value, exc_traceback):
173+
def __exit__(
174+
self,
175+
exctype: Optional[Type[BaseException]],
176+
excinst: Optional[BaseException],
177+
exctb: Optional[TracebackType],
178+
) -> bool:
144179
cleanup_container(self.client, self.container, self.logger)
145180
close_logger(self.logger)
146181

@@ -159,22 +194,22 @@ def __init__(
159194
# the image must exist on dockerhub
160195
reponame = spec.repo.split("/")[-1]
161196
image_name = f"wentingzhao/{reponame}"
162-
image = (
163-
modal.Image.from_registry(image_name)
164-
.copy_local_file(eval_file, "/eval.sh")
197+
image = modal.Image.from_registry(image_name).copy_local_file(
198+
eval_file, "/eval.sh"
165199
)
166200

167201
self.sandbox = modal.Sandbox.create(
168202
"sleep",
169203
"infinity",
170204
image=image,
171205
cpu=4.0,
172-
timeout=300,
206+
timeout=timeout,
173207
)
174208

175209
self.copy_ssh_pubkey_from_remote()
176210

177-
def copy_ssh_pubkey_from_remote(self):
211+
def copy_ssh_pubkey_from_remote(self) -> None:
212+
"""Copy ssh pubkey"""
178213
process = self.sandbox.exec("bash", "-c", "cat /root/.ssh/id_rsa.pub")
179214
public_key = "".join([line for line in process.stdout]).strip()
180215

@@ -197,6 +232,7 @@ def copy_ssh_pubkey_from_remote(self):
197232
authorized_keys_file.write(public_key + "\n")
198233

199234
def copy_to_remote(self, local_path: Path, remote_path: Path) -> None:
235+
"""Copy"""
200236
tempname = "tmpfile"
201237
with local_path.open("rb") as f:
202238
self.nfs.write_file(tempname, f)
@@ -229,9 +265,15 @@ def copy_from_remote(self, remote_path: Path, local_path: Path) -> None:
229265
with local_path.open("w") as f:
230266
f.write(output)
231267

232-
def delete_file_from_remote(src, remote_path):
268+
def delete_file_from_remote(self, remote_path: Path) -> None:
269+
"""Delete"""
233270
self.sandbox.exec("bash", "-c", f"rm {str(remote_path)}")
234271

235-
def __exit__(self, exc_type, exc_value, exc_traceback):
272+
def __exit__(
273+
self,
274+
exctype: Optional[Type[BaseException]],
275+
excinst: Optional[BaseException],
276+
exctb: Optional[TracebackType],
277+
) -> bool:
236278
self.sandbox.terminate()
237279
close_logger(self.logger)

commit0/harness/run_pytest_ids.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
from datasets import load_dataset
2-
import docker
32
from enum import StrEnum, auto
43
import traceback
54
from pathlib import Path
6-
import logging
75

86
from typing import Iterator
97
from git import Repo
108
from commit0.harness.constants import RUN_PYTEST_LOG_DIR, RepoInstance
119
from commit0.harness.docker_build import (
12-
close_logger,
1310
setup_logger,
1411
)
15-
from commit0.harness.spec import Spec, make_spec
12+
from commit0.harness.spec import make_spec
1613
from commit0.harness.utils import (
1714
EvaluationError,
1815
extract_test_output,
@@ -109,7 +106,6 @@ def main(
109106
# f"Check ({logger.log_file}) for more information."
110107
)
111108
logger.error(error_msg)
112-
113109

114110
return str(log_dir)
115111

0 commit comments

Comments
 (0)