4
4
and HTTP servers.
5
5
"""
6
6
7
- from abc import ABC
7
+ from abc import ABC , abstractmethod
8
8
import docker
9
9
import logging
10
10
import os
11
11
import modal
12
12
from pathlib import Path
13
+ from typing import Optional , Type
14
+ from types import TracebackType
13
15
14
16
from commit0 .harness .spec import Spec
17
+ from commit0 .harness .utils import (
18
+ EvaluationError ,
19
+ )
15
20
from commit0 .harness .docker_build import (
16
21
close_logger ,
17
22
)
27
32
28
33
29
34
def read_stream (stream : modal .io_streams .StreamReader ) -> str :
35
+ """Read stream"""
30
36
strings = []
31
37
for line in stream :
32
38
strings .append (line )
@@ -53,33 +59,46 @@ def __init__(
53
59
self .timeout = timeout
54
60
self .log_dir = log_dir
55
61
62
+ @abstractmethod
56
63
def copy_ssh_pubkey_from_remote (self ) -> None :
64
+ """Copy"""
57
65
raise NotImplementedError
58
66
67
+ @abstractmethod
59
68
def copy_to_remote (self , local_path : Path , remote_path : Path ) -> None :
69
+ """Copy"""
60
70
raise NotImplementedError
61
71
72
+ @abstractmethod
62
73
def exec_run_with_timeout (self , command : str , timeout : int ) -> None :
74
+ """Exec"""
63
75
raise NotImplementedError
64
76
77
+ @abstractmethod
65
78
def exec_run (self , command : str ) -> None :
79
+ """Exec"""
66
80
raise NotImplementedError
67
81
82
+ @abstractmethod
68
83
def copy_from_remote (self , remote_path : Path , local_path : Path ) -> None :
84
+ """Copy"""
69
85
raise NotImplementedError
70
86
87
+ @abstractmethod
71
88
def delete_file_from_remote (self , remote_path : Path ) -> None :
89
+ """Delete"""
72
90
raise NotImplementedError
73
91
74
- def write_test_output (self , test_output , timed_out ):
92
+ def write_test_output (self , test_output : str , timed_out : bool ) -> None :
93
+ """Write test output"""
75
94
test_output_path = self .log_dir / "test_output.txt"
76
95
with open (test_output_path , "w" ) as f :
77
96
f .write (test_output )
78
97
if timed_out :
79
- f .write (f"\n \n Timeout error: { timeout } seconds exceeded." )
98
+ f .write (f"\n \n Timeout error: { self . timeout } seconds exceeded." )
80
99
raise EvaluationError (
81
100
self .spec .repo ,
82
- f"Test timed out after { timeout } seconds." ,
101
+ f"Test timed out after { self . timeout } seconds." ,
83
102
self .logger ,
84
103
)
85
104
@@ -95,7 +114,13 @@ def write_test_output(self, test_output, timed_out):
95
114
def __enter__ (self ):
96
115
return self
97
116
98
- def __exit__ (self , exc_type , exc_value , exc_traceback ):
117
+ @abstractmethod
118
+ def __exit__ (
119
+ self ,
120
+ exctype : Optional [Type [BaseException ]],
121
+ excinst : Optional [BaseException ],
122
+ exctb : Optional [TracebackType ],
123
+ ) -> bool :
99
124
raise NotImplementedError
100
125
101
126
@@ -121,26 +146,36 @@ def __init__(
121
146
self .copy_ssh_pubkey_from_remote ()
122
147
copy_to_container (self .container , eval_file , Path ("/eval.sh" ))
123
148
124
-
125
149
def copy_ssh_pubkey_from_remote (self ) -> None :
150
+ """Copy"""
126
151
copy_ssh_pubkey_from_container (self .container )
127
152
128
153
def copy_to_remote (self , local_file : Path , remote_path : Path ) -> None :
154
+ """Copy"""
129
155
copy_to_container (self .container , local_file , remote_path )
130
156
131
157
def exec_run_with_timeout (self , command : str , timeout : int ) -> ():
158
+ """Exec"""
132
159
return exec_run_with_timeout (self .container , command , timeout )
133
160
134
161
def exec_run (self , command : str ) -> None :
162
+ """Exec"""
135
163
return self .container .exec_run (command , demux = True )
136
164
137
165
def copy_from_remote (self , remote_path : Path , local_path : Path ) -> None :
166
+ """Copy"""
138
167
copy_from_container (self .container , remote_path , local_path )
139
168
140
169
def delete_file_from_remote (self , remote_path : Path ) -> None :
170
+ """Delete"""
141
171
delete_file_from_container (self .container , str (remote_path ))
142
172
143
- def __exit__ (self , exc_type , exc_value , exc_traceback ):
173
+ def __exit__ (
174
+ self ,
175
+ exctype : Optional [Type [BaseException ]],
176
+ excinst : Optional [BaseException ],
177
+ exctb : Optional [TracebackType ],
178
+ ) -> bool :
144
179
cleanup_container (self .client , self .container , self .logger )
145
180
close_logger (self .logger )
146
181
@@ -159,22 +194,22 @@ def __init__(
159
194
# the image must exist on dockerhub
160
195
reponame = spec .repo .split ("/" )[- 1 ]
161
196
image_name = f"wentingzhao/{ reponame } "
162
- image = (
163
- modal .Image .from_registry (image_name )
164
- .copy_local_file (eval_file , "/eval.sh" )
197
+ image = modal .Image .from_registry (image_name ).copy_local_file (
198
+ eval_file , "/eval.sh"
165
199
)
166
200
167
201
self .sandbox = modal .Sandbox .create (
168
202
"sleep" ,
169
203
"infinity" ,
170
204
image = image ,
171
205
cpu = 4.0 ,
172
- timeout = 300 ,
206
+ timeout = timeout ,
173
207
)
174
208
175
209
self .copy_ssh_pubkey_from_remote ()
176
210
177
- def copy_ssh_pubkey_from_remote (self ):
211
+ def copy_ssh_pubkey_from_remote (self ) -> None :
212
+ """Copy ssh pubkey"""
178
213
process = self .sandbox .exec ("bash" , "-c" , "cat /root/.ssh/id_rsa.pub" )
179
214
public_key = "" .join ([line for line in process .stdout ]).strip ()
180
215
@@ -197,6 +232,7 @@ def copy_ssh_pubkey_from_remote(self):
197
232
authorized_keys_file .write (public_key + "\n " )
198
233
199
234
def copy_to_remote (self , local_path : Path , remote_path : Path ) -> None :
235
+ """Copy"""
200
236
tempname = "tmpfile"
201
237
with local_path .open ("rb" ) as f :
202
238
self .nfs .write_file (tempname , f )
@@ -229,9 +265,15 @@ def copy_from_remote(self, remote_path: Path, local_path: Path) -> None:
229
265
with local_path .open ("w" ) as f :
230
266
f .write (output )
231
267
232
- def delete_file_from_remote (src , remote_path ):
268
+ def delete_file_from_remote (self , remote_path : Path ) -> None :
269
+ """Delete"""
233
270
self .sandbox .exec ("bash" , "-c" , f"rm { str (remote_path )} " )
234
271
235
- def __exit__ (self , exc_type , exc_value , exc_traceback ):
272
+ def __exit__ (
273
+ self ,
274
+ exctype : Optional [Type [BaseException ]],
275
+ excinst : Optional [BaseException ],
276
+ exctb : Optional [TracebackType ],
277
+ ) -> bool :
236
278
self .sandbox .terminate ()
237
279
close_logger (self .logger )
0 commit comments