import uuid
import heapq
import botocore
import time
import threading
from typing import List, Generator, Optional
# of None or anything else that makes sense in your case.
@dataclass
class ExecutorSettings(ExecutorSettingsBase):
+    access_key_id: Optional[str] = field(
+        default=None,
+        metadata={"help": "AWS access key id", "env_var": True, "required": False},
+        repr=False,
+    )
+    access_key: Optional[str] = field(
+        default=None,
+        metadata={"help": "AWS secret access key", "env_var": True, "required": False},
+        repr=False,
+    )
    region: Optional[str] = field(
        default=None,
        metadata={
@@ -58,23 +69,23 @@ class ExecutorSettings(ExecutorSettingsBase):
            },
        ),
    )
-    task_queue: Optional[str] = field(
+    job_queue: Optional[str] = field(
        default=None,
        metadata={
            "help": "The AWS Batch job queue ARN used for running tasks",
-            "env_var": False,
+            "env_var": True,
            "required": True,
        },
    )
-    workflow_role: Optional[str] = field(
+    execution_role: Optional[str] = field(
        default=None,
        metadata={
-            "help": "The AWS role that is used for running the tasks",
-            "env_var": False,
+            "help": "The AWS execution role ARN that is used for running the tasks",
+            "env_var": True,
            "required": True,
        },
    )
-    tags: Optional[List[str]] = field(
+    tags: Optional[dict] = field(
        default=None,
        metadata={
            "help": (
@@ -105,7 +116,9 @@ class ExecutorSettings(ExecutorSettingsBase):
    # plugins (snakemake-executor-plugin-dryrun, snakemake-executor-plugin-local)
    # are expected to specify False here.
    non_local_exec=True,
+    # whether the executor implies that there is no shared file system
    implies_no_shared_fs=True,
+    # whether to deploy workflow sources to the default storage provider before execution
    job_deploy_sources=True,
    # whether arguments for setting the storage provider shall be passed to jobs
    pass_default_storage_provider_args=True,
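
The settings above map onto a plain dataclass, so a minimal sketch of how they fit together (all ARNs and values below are hypothetical placeholders; in practice snakemake populates them from CLI flags, a profile, or environment variables):

```python
# illustrative only: hand-constructed settings with placeholder values
settings = ExecutorSettings(
    access_key_id=None,  # None lets boto3 fall back to its default credential chain
    access_key=None,
    region="us-east-1",
    job_queue="arn:aws:batch:us-east-1:123456789012:job-queue/fargate-queue",
    execution_role="arn:aws:iam::123456789012:role/ecsTaskExecutionRole",
    tags={"project": "demo"},  # a dict now, matching the AWS Batch tags API
)
```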
@@ -136,15 +149,21 @@ def __post_init__(self):

        # keep track of job definitions
        self.created_job_defs = list()
+        self.mount_path = None
+        self._describer = BatchJobDescriber()

        # init batch client
        try:
-            self.aws = boto3.Session().client(  # Session() needed for thread safety
-                "batch",
-                region_name=self.settings.region,
-                config=botocore.config.Config(
-                    retries={"max_attempts": 5, "mode": "standard"}
-                ),
+            self.batch_client = (
+                boto3.Session().client(  # Session() needed for thread safety
+                    "batch",
+                    aws_access_key_id=self.settings.access_key_id,
+                    aws_secret_access_key=self.settings.access_key,
+                    region_name=self.settings.region,
+                    config=botocore.config.Config(
+                        retries={"max_attempts": 5, "mode": "standard"}
+                    ),
+                )
            )
        except Exception as e:
            raise WorkflowError(e)
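
A side note on the new credential arguments: boto3 treats an explicit None as "not provided", so when neither access key setting is configured the client above still works and resolves credentials through the default chain (environment variables, shared credentials file, or an attached IAM role). A minimal sketch:

```python
import boto3

# equivalent to omitting the arguments entirely; credentials come from
# the default chain when both values are None
client = boto3.Session().client(
    "batch",
    aws_access_key_id=None,
    aws_secret_access_key=None,
    region_name="us-east-1",  # placeholder region
)
```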
@@ -194,19 +213,25 @@ def run_job(self, job: JobExecutorInterface):
        job_definition_type = "container"

        # get job resources or default
-        vcpu = job.resources.get("_cores", str(1))
-        mem = job.resources.get("mem_mb", str(1024))
+        vcpu = str(job.resources.get("_cores", 1))
+        mem = str(job.resources.get("mem_mb", 2048))

        # job definition container properties
        container_properties = {
-            "executionRoleArn": self.settings.workflow_role,
            "command": ["snakemake"],
            "image": self.container_image,
-            "privileged": True,
+            # Fargate requires privileged to be False
+            "privileged": False,
            "resourceRequirements": [
-                {"type": "VCPU", "value": "1"},
-                {"type": "MEMORY", "value": "1024"},
+                # resource requirements have to be compatible with Fargate
+                # see: https://docs.aws.amazon.com/batch/latest/APIReference/API_ResourceRequirement.html  # noqa
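+                # e.g. a job requesting 1 vCPU must pair it with 2048-8192 MB
+                # of memory, in 1024 MB increments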
+                {"type": "VCPU", "value": vcpu},
+                {"type": "MEMORY", "value": mem},
            ],
+            "networkConfiguration": {
+                "assignPublicIp": "ENABLED",
+            },
+            "executionRoleArn": self.settings.execution_role,
        }

        # TODO: or not todo?
@@ -216,14 +241,16 @@ def run_job(self, job: JobExecutorInterface):
        # ) = self._prepare_mounts()

        # register the job definition
+        tags = self.settings.tags if isinstance(self.settings.tags, dict) else dict()
        try:
-            job_def = self.aws.register_job_definition(
+            job_def = self.batch_client.register_job_definition(
                jobDefinitionName=job_definition_name,
                type=job_definition_type,
                containerProperties=container_properties,
-                tags=self.settings.tags,
+                platformCapabilities=["FARGATE"],
+                tags=tags,
            )
-            self.created_job_defs.append(self.job_def)
+            self.created_job_defs.append(job_def)
        except Exception as e:
            raise WorkflowError(e)

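For context on the lookups below: the boto3 `register_job_definition` response carries the definition's name, ARN, and revision, roughly shaped like this (placeholder values):

```python
# illustrative response shape
job_def = {
    "jobDefinitionName": "snakejob-abc123",
    "jobDefinitionArn": "arn:aws:batch:us-east-1:123456789012:job-definition/snakejob-abc123:1",
    "revision": 1,
}
```

`submit_job` then takes the definition as `name:revision`, and the ARN is what `_deregister_job` uses for cleanup at the end.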
@@ -232,7 +259,7 @@ def run_job(self, job: JobExecutorInterface):
        # configure job parameters
        job_params = {
            "jobName": job_name,
-            "jobQueue": self.settings.task_queue,
+            "jobQueue": self.settings.job_queue,
            "jobDefinition": "{}:{}".format(
                job_def["jobDefinitionName"], job_def["revision"]
            ),
@@ -243,18 +270,22 @@ def run_job(self, job: JobExecutorInterface):
                    {"type": "MEMORY", "value": mem},
                ],
            },
-            "tags": self.settings.tags,
        }

+        if self.settings.tags:
+            job_params["tags"] = self.settings.tags
+
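+        # AWS Batch requires attemptDurationSeconds to be at least 60 seconds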
        if self.settings.task_timeout is not None:
-            job_params["timeout"] = {"attemptDurationSeconds": self.task_timeout}
+            job_params["timeout"] = {
+                "attemptDurationSeconds": self.settings.task_timeout
+            }

        # submit the job
        try:
-            submitted = self.aws.submit_job(**job_params)
+            submitted = self.batch_client.submit_job(**job_params)
            self.logger.debug(
                "AWS Batch job submitted with queue {}, jobId {} and tags {}".format(
-                    self.task_queue, job["jobId"], self.tags
+                    self.settings.job_queue, submitted["jobId"], self.settings.tags
                )
            )
        except Exception as e:
@@ -274,13 +305,7 @@ def run_job(self, job: JobExecutorInterface):
    def _generate_snakemake_command(self, job: JobExecutorInterface) -> List[str]:
        """generates the snakemake command for the job"""
        exec_job = self.format_job_exec(job)
-        command = list(filter(None, exec_job.replace('"', "").split(" ")))
-        return_command = ["sh", "-c"]
-        snakemake_run_command = "cd {}/{} && {}".format(
-            self.mount_path, self.efs_project_path, " ".join(command)
-        )
-        return_command.append(snakemake_run_command)
-        return return_command
+        # exec_job is handed to `sh -c` as a single argv element, so it must
+        # not be shell-quoted again (quoting it would make sh treat the whole
+        # command line as a single command name)
+        return ["sh", "-c", exec_job]

    async def check_active_jobs(
        self, active_jobs: List[SubmittedJobInfo]
@@ -333,7 +358,7 @@ def _get_job_status(self, job: SubmittedJobInfo) -> (int, Optional[str]):
        ]
        exit_code = None
        log_stream_name = None
-        job_desc = self._describer.describe(self.aws, job.external_jobid, 1)
+        job_desc = self._describer.describe(self.batch_client, job.external_jobid, 1)
        job_status = job_desc["status"]

        # set log stream name if not none
@@ -392,7 +417,7 @@ def _terminate_job(self, job: SubmittedJobInfo):
        """terminate job from submitted job info"""
        try:
            self.logger.debug(f"terminating job {job.external_jobid}")
-            self.aws.terminate_job(
+            self.batch_client.terminate_job(
                jobId=job.external_jobid,
                reason="terminated by snakemake",
            )
@@ -407,7 +432,7 @@ def _deregister_job(self, job: SubmittedJobInfo):
        try:
            job_def_arn = job.aux["jobDefArn"]
            self.logger.debug(f"de-registering Batch job definition {job_def_arn}")
-            self.aws.deregister_job_definition(jobDefinition=job_def_arn)
+            self.batch_client.deregister_job_definition(jobDefinition=job_def_arn)
        except Exception as e:
            # AWS expires job definitions after 6mo,
            # so failing to delete them isn't fatal