@@ -32,6 +32,7 @@ def read_stream(stream: modal.io_streams.StreamReader) -> str:
32
32
strings .append (line )
33
33
return "\n " .join (strings )
34
34
35
+
35
36
class ExecutionContext (ABC ):
36
37
def __init__ (
37
38
self ,
@@ -46,7 +47,11 @@ def __init__(
46
47
The execution context will persist for the lifetime of this object.
47
48
The execution context can be a Docker container or Modal sandbox.
48
49
"""
49
- raise NotImplementedError
50
+ self .spec = spec
51
+ self .logger = logger
52
+ self .eval_file = eval_file
53
+ self .timeout = timeout
54
+ self .log_dir = log_dir
50
55
51
56
def copy_ssh_pubkey_from_remote (self ) -> None :
52
57
raise NotImplementedError
@@ -60,14 +65,35 @@ def exec_run_with_timeout(self, command: str, timeout: int) -> None:
60
65
def exec_run (self , command : str ) -> None :
61
66
raise NotImplementedError
62
67
63
- def copy_from_remote (self , remote_path , local_path ) -> None :
68
+ def copy_from_remote (self , remote_path : Path , local_path : Path ) -> None :
64
69
raise NotImplementedError
65
70
66
- def delete_file_from_remote (self , remote_path ) -> None :
71
+ def delete_file_from_remote (self , remote_path : Path ) -> None :
67
72
raise NotImplementedError
68
73
74
+ def write_test_output (self , test_output , timed_out ):
75
+ test_output_path = self .log_dir / "test_output.txt"
76
+ with open (test_output_path , "w" ) as f :
77
+ f .write (test_output )
78
+ if timed_out :
79
+ f .write (f"\n \n Timeout error: { timeout } seconds exceeded." )
80
+ raise EvaluationError (
81
+ self .spec .repo ,
82
+ f"Test timed out after { timeout } seconds." ,
83
+ self .logger ,
84
+ )
85
+
86
+ # copy back report.json if there is any
87
+ report_file = Path (self .spec .repo_directory ) / "report.json"
88
+ # Run the test command inside the container to check if the file exists
89
+ exit_code , output = self .exec_run (f"test -e { report_file } " )
90
+ # Check the exit code of the command
91
+ if exit_code == 0 :
92
+ self .copy_from_remote (report_file , self .log_dir / "report.json" )
93
+ self .delete_file_from_remote (str (report_file ))
94
+
69
95
def __enter__ (self ):
70
- raise NotImplementedError
96
+ return self
71
97
72
98
def __exit__ (self , exc_type , exc_value , exc_traceback ):
73
99
raise NotImplementedError
@@ -82,8 +108,9 @@ def __init__(
82
108
timeout : int ,
83
109
log_dir : Path ,
84
110
):
111
+ super ().__init__ (spec , logger , eval_file , timeout , log_dir )
112
+
85
113
self .client = docker .from_env ()
86
- self .logger = logger
87
114
self .container = create_container (
88
115
client = self .client ,
89
116
image_name = spec .repo_image_key ,
@@ -92,6 +119,8 @@ def __init__(
92
119
)
93
120
self .container .start ()
94
121
self .copy_ssh_pubkey_from_remote ()
122
+ copy_to_container (self .container , eval_file , Path ("/eval.sh" ))
123
+
95
124
96
125
def copy_ssh_pubkey_from_remote (self ) -> None :
97
126
copy_ssh_pubkey_from_container (self .container )
@@ -111,9 +140,6 @@ def copy_from_remote(self, remote_path: Path, local_path: Path) -> None:
111
140
def delete_file_from_remote (self , remote_path : Path ) -> None :
112
141
delete_file_from_container (self .container , str (remote_path ))
113
142
114
- def __enter__ (self ):
115
- return self
116
-
117
143
def __exit__ (self , exc_type , exc_value , exc_traceback ):
118
144
cleanup_container (self .client , self .container , self .logger )
119
145
close_logger (self .logger )
@@ -128,22 +154,22 @@ def __init__(
128
154
timeout : int ,
129
155
log_dir : Path ,
130
156
):
131
- self .logger = logger
157
+ super ().__init_ (spec , logger , eval_file , timeout , log_dir )
158
+
132
159
# the image must exist on dockerhub
133
160
reponame = spec .repo .split ("/" )[- 1 ]
134
161
image_name = f"wentingzhao/{ reponame } "
135
- image = modal .Image .from_registry (image_name )
162
+ image = (
163
+ modal .Image .from_registry (image_name )
164
+ .copy_local_file (eval_file , "/eval.sh" )
165
+ )
136
166
137
- self .nfs = modal .NetworkFileSystem .ephemeral ().__enter__ ()
138
167
self .sandbox = modal .Sandbox .create (
139
168
"sleep" ,
140
169
"infinity" ,
141
170
image = image ,
142
- network_file_systems = {
143
- "/vol" : self .nfs ,
144
- },
145
- cpu = 8.0 ,
146
- timeout = 30 ,
171
+ cpu = 4.0 ,
172
+ timeout = 300 ,
147
173
)
148
174
149
175
self .copy_ssh_pubkey_from_remote ()
@@ -206,9 +232,6 @@ def copy_from_remote(self, remote_path: Path, local_path: Path) -> None:
206
232
def delete_file_from_remote (src , remote_path ):
207
233
self .sandbox .exec ("bash" , "-c" , f"rm { str (remote_path )} " )
208
234
209
- def __enter__ (self ):
210
- return self
211
-
212
235
def __exit__ (self , exc_type , exc_value , exc_traceback ):
213
- # self.nfs.__exit__ ()
214
- pass
236
+ self .sandbox . terminate ()
237
+ close_logger ( self . logger )
0 commit comments