
Commit 853f16d

refactor: container options, resource reqs
1 parent 7934779 commit 853f16d

3 files changed: +82 −38 lines changed

snakemake_executor_plugin_aws_batch/__init__.py

Lines changed: 61 additions & 36 deletions
@@ -4,6 +4,7 @@
 import uuid
 import heapq
 import botocore
+import shlex
 import time
 import threading
 from typing import List, Generator, Optional
@@ -27,6 +28,16 @@
 # of None or anything else that makes sense in your case.
 @dataclass
 class ExecutorSettings(ExecutorSettingsBase):
+    access_key_id: Optional[int] = field(
+        default=None,
+        metadata={"help": "AWS access key id", "env_var": True, "required": False},
+        repr=False,
+    )
+    access_key: Optional[int] = field(
+        default=None,
+        metadata={"help": "AWS access key", "env_var": True, "required": False},
+        repr=False,
+    )
     region: Optional[int] = field(
         default=None,
         metadata={
@@ -58,23 +69,23 @@ class ExecutorSettings(ExecutorSettingsBase):
             },
         ),
     )
-    task_queue: Optional[str] = field(
+    job_queue: Optional[str] = field(
         default=None,
         metadata={
             "help": "The AWS Batch task queue ARN used for running tasks",
-            "env_var": False,
+            "env_var": True,
             "required": True,
         },
     )
-    workflow_role: Optional[str] = field(
+    execution_role: Optional[str] = field(
         default=None,
         metadata={
-            "help": "The AWS role that is used for running the tasks",
-            "env_var": False,
+            "help": "The AWS execution role ARN that is used for running the tasks",
+            "env_var": True,
             "required": True,
         },
     )
-    tags: Optional[List[str]] = field(
+    tags: Optional[dict] = field(
         default=None,
         metadata={
             "help": (
@@ -105,7 +116,9 @@ class ExecutorSettings(ExecutorSettingsBase):
     # plugins (snakemake-executor-plugin-dryrun, snakemake-executor-plugin-local)
     # are expected to specify False here.
     non_local_exec=True,
+    # whether the executor implies to not have a shared file system
     implies_no_shared_fs=True,
+    # whether to deploy workflow sources to default storage provider before execution
     job_deploy_sources=True,
     # whether arguments for setting the storage provider shall be passed to jobs
     pass_default_storage_provider_args=True,
@@ -136,15 +149,21 @@ def __post_init__(self):

         # keep track of job definitions
         self.created_job_defs = list()
+        self.mount_path = None
+        self._describer = BatchJobDescriber()

         # init batch client
         try:
-            self.aws = boto3.Session().client(  # Session() needed for thread safety
-                "batch",
-                region_name=self.settings.region,
-                config=botocore.config.Config(
-                    retries={"max_attempts": 5, "mode": "standard"}
-                ),
+            self.batch_client = (
+                boto3.Session().client(  # Session() needed for thread safety
+                    "batch",
+                    aws_access_key_id=self.settings.access_key_id,
+                    aws_secret_access_key=self.settings.access_key,
+                    region_name=self.settings.region,
+                    config=botocore.config.Config(
+                        retries={"max_attempts": 5, "mode": "standard"}
+                    ),
+                )
             )
         except Exception as e:
             raise WorkflowError(e)
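Note (not part of this commit): for debugging outside the plugin, roughly the same client can be built standalone; this is a sketch using only documented boto3 Batch calls, with the region and queue name as placeholders.

import boto3
import botocore

session = boto3.Session()  # one Session per thread keeps boto3 thread-safe
batch_client = session.client(
    "batch",
    region_name="us-east-1",
    config=botocore.config.Config(retries={"max_attempts": 5, "mode": "standard"}),
)

# optional sanity check that the configured queue exists and is enabled
resp = batch_client.describe_job_queues(jobQueues=["snakemake"])  # placeholder queue
print(resp["jobQueues"][0]["state"])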
@@ -194,19 +213,25 @@ def run_job(self, job: JobExecutorInterface):
         job_definition_type = "container"

         # get job resources or default
-        vcpu = job.resources.get("_cores", str(1))
-        mem = job.resources.get("mem_mb", str(1024))
+        vcpu = str(job.resources.get("_cores", str(1)))
+        mem = str(job.resources.get("mem_mb", str(2048)))

         # job definition container properties
         container_properties = {
-            "executionRoleArn": self.settings.workflow_role,
             "command": ["snakemake"],
             "image": self.container_image,
-            "privileged": True,
+            # fargate required privileged False
+            "privileged": False,
             "resourceRequirements": [
-                {"type": "VCPU", "value": "1"},
-                {"type": "MEMORY", "value": "1024"},
+                # resource requirements have to be compatible
+                # see: https://docs.aws.amazon.com/batch/latest/APIReference/API_ResourceRequirement.html # noqa
+                {"type": "VCPU", "value": vcpu},
+                {"type": "MEMORY", "value": mem},
             ],
+            "networkConfiguration": {
+                "assignPublicIp": "ENABLED",
+            },
+            "executionRoleArn": self.settings.execution_role,
         }

         # TODO: or not todo ?
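Note (not part of this commit): because the job definition now targets Fargate, the VCPU/MEMORY pair has to be one of the combinations AWS Batch allows, so arbitrary _cores/mem_mb values can be rejected at registration time. Below is a hedged helper sketch, not in the plugin, that snaps a request to the nearest valid pair using the combinations listed in the AWS documentation.

# valid Fargate pairs per the AWS Batch documentation (vCPU -> allowed MiB values)
FARGATE_COMBOS = {
    "0.25": [512, 1024, 2048],
    "0.5": [1024 * i for i in range(1, 5)],   # 1024-4096 MiB
    "1": [1024 * i for i in range(2, 9)],     # 2048-8192 MiB
    "2": [1024 * i for i in range(4, 17)],    # 4096-16384 MiB
    "4": [1024 * i for i in range(8, 31)],    # 8192-30720 MiB
}

def snap_to_fargate(vcpu: str, mem: str) -> tuple[str, str]:
    """Pick the smallest valid Fargate combination satisfying the request."""
    for cpu, mems in FARGATE_COMBOS.items():
        if float(cpu) >= float(vcpu):
            for m in mems:
                if m >= int(mem):
                    return cpu, str(m)
    raise ValueError(f"request exceeds Fargate limits: {vcpu} vCPU / {mem} MiB")

For example, snap_to_fargate("1", "3000") returns ("1", "3072").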
@@ -216,14 +241,16 @@ def run_job(self, job: JobExecutorInterface):
         # ) = self._prepare_mounts()

         # register the job definition
+        tags = self.settings.tags if isinstance(self.settings.tags, dict) else dict()
         try:
-            job_def = self.aws.register_job_definition(
+            job_def = self.batch_client.register_job_definition(
                 jobDefinitionName=job_definition_name,
                 type=job_definition_type,
                 containerProperties=container_properties,
-                tags=self.settings.tags,
+                platformCapabilities=["FARGATE"],
+                tags=tags,
             )
-            self.created_job_defs.append(self.job_def)
+            self.created_job_defs.append(job_def)
         except Exception as e:
             raise WorkflowError(e)

@@ -232,7 +259,7 @@ def run_job(self, job: JobExecutorInterface):
         # configure job parameters
         job_params = {
             "jobName": job_name,
-            "jobQueue": self.settings.task_queue,
+            "jobQueue": self.settings.job_queue,
             "jobDefinition": "{}:{}".format(
                 job_def["jobDefinitionName"], job_def["revision"]
             ),
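Note (not part of this commit): the "name:revision" reference above comes straight from the register_job_definition response. A self-contained sketch of that call, with placeholder image, role, and names, showing the keys boto3 returns:

job_def = batch_client.register_job_definition(  # batch_client as constructed earlier
    jobDefinitionName="snakemake-job",            # placeholder name
    type="container",
    platformCapabilities=["FARGATE"],
    containerProperties={
        "image": "snakemake/snakemake:latest",    # placeholder image
        "command": ["snakemake"],
        "resourceRequirements": [
            {"type": "VCPU", "value": "1"},
            {"type": "MEMORY", "value": "2048"},
        ],
        "networkConfiguration": {"assignPublicIp": "ENABLED"},
        "executionRoleArn": "arn:aws:iam::123456789012:role/snakemake-batch-execution",
    },
)
# the response carries the pieces used to build the "name:revision" reference
print(job_def["jobDefinitionName"], job_def["revision"], job_def["jobDefinitionArn"])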
@@ -243,18 +270,22 @@ def run_job(self, job: JobExecutorInterface):
                     {"type": "MEMORY", "value": mem},
                 ],
             },
-            "tags": self.settings.tags,
         }

+        if self.settings.tags:
+            job_params["tags"] = self.settings.tags
+
         if self.settings.task_timeout is not None:
-            job_params["timeout"] = {"attemptDurationSeconds": self.task_timeout}
+            job_params["timeout"] = {
+                "attemptDurationSeconds": self.settings.task_timeout
+            }

         # submit the job
         try:
-            submitted = self.aws.submit_job(**job_params)
+            submitted = self.batch_client.submit_job(**job_params)
             self.logger.debug(
                 "AWS Batch job submitted with queue {}, jobId {} and tags {}".format(
-                    self.task_queue, job["jobId"], self.tags
+                    self.settings.job_queue, submitted["jobId"], self.settings.tags
                 )
             )
         except Exception as e:
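Note (not part of this commit): putting the pieces together, a rough standalone sketch of the submission call and a one-off status check, assuming the client from above; all names and values are placeholders.

submitted = batch_client.submit_job(
    jobName="snakemake-abc123",           # placeholder
    jobQueue="snakemake",                 # queue name or ARN
    jobDefinition="snakemake-job:1",      # "name:revision"
    containerOverrides={
        "command": ["sh", "-c", "snakemake --version"],
        "resourceRequirements": [
            {"type": "VCPU", "value": "1"},
            {"type": "MEMORY", "value": "2048"},
        ],
    },
    timeout={"attemptDurationSeconds": 3600},
)

status = batch_client.describe_jobs(jobs=[submitted["jobId"]])["jobs"][0]["status"]
print(submitted["jobId"], status)  # e.g. SUBMITTED, RUNNABLE, RUNNING, SUCCEEDED, FAILED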
@@ -274,13 +305,7 @@ def run_job(self, job: JobExecutorInterface):
     def _generate_snakemake_command(self, job: JobExecutorInterface) -> str:
         """generates the snakemake command for the job"""
         exec_job = self.format_job_exec(job)
-        command = list(filter(None, exec_job.replace('"', "").split(" ")))
-        return_command = ["sh", "-c"]
-        snakemake_run_command = "cd {}/{} && {}".format(
-            self.mount_path, self.efs_project_path, " ".join(command)
-        )
-        return_command.append(snakemake_run_command)
-        return return_command
+        return ["sh", "-c", shlex.quote(exec_job)]

     async def check_active_jobs(
         self, active_jobs: List[SubmittedJobInfo]
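Note (not part of this commit): the rewritten helper wraps the formatted job command with shlex.quote so it travels as one shell-quoted token instead of being split on whitespace; a quick illustration of what the resulting argument looks like (the command string is made up):

import shlex

exec_job = 'snakemake --snakefile "Snakefile" --cores 1 --target-jobs foo:sample=A'
command = ["sh", "-c", shlex.quote(exec_job)]
print(command[2])
# 'snakemake --snakefile "Snakefile" --cores 1 --target-jobs foo:sample=A'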
@@ -333,7 +358,7 @@ def _get_job_status(self, job: SubmittedJobInfo) -> (int, Optional[str]):
         ]
         exit_code = None
         log_stream_name = None
-        job_desc = self._describer.describe(self.aws, job.external_jobid, 1)
+        job_desc = self._describer.describe(self.batch_client, job.external_jobid, 1)
         job_status = job_desc["status"]

         # set log stream name if not none
@@ -392,7 +417,7 @@ def _terminate_job(self, job: SubmittedJobInfo):
         """terminate job from submitted job info"""
         try:
             self.logger.debug(f"terminating job {job.external_jobid}")
-            self.aws.terminate_job(
+            self.batch_client.terminate_job(
                 jobId=job.external_jobid,
                 reason="terminated by snakemake",
             )
@@ -407,7 +432,7 @@ def _deregister_job(self, job: SubmittedJobInfo):
         try:
             job_def_arn = job.aux["jobDefArn"]
             self.logger.debug(f"de-registering Batch job definition {job_def_arn}")
-            self.aws.deregister_job_definition(jobDefinition=job_def_arn)
+            self.batch_client.deregister_job_definition(jobDefinition=job_def_arn)
         except Exception as e:
             # AWS expires job definitions after 6mo
             # so failing to delete them isn't fatal
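Note (not part of this commit): for completeness, the two cleanup calls used above have these boto3 shapes; the job id and definition ARN below are placeholders.

batch_client.terminate_job(
    jobId="12345678-aaaa-bbbb-cccc-1234567890ab",
    reason="terminated by snakemake",
)
batch_client.deregister_job_definition(
    jobDefinition="arn:aws:batch:us-east-1:123456789012:job-definition/snakemake-job:1"
)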

tests/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -14,8 +14,11 @@ def get_executor(self) -> str:
     def get_executor_settings(self) -> Optional[ExecutorSettingsBase]:
         # instantiate ExecutorSettings of this plugin as appropriate
         return ExecutorSettings(
-            account_url=os.getenv("SNAKEMAKE_AWS_"),
-            account_key=os.getenv("SNAKEMAKE_AWS_"),
+            access_key_id=os.getenv("SNAKEMAKE_AWS_BATCH_ACCESS_KEY_ID"),
+            access_key=os.getenv("SNAKEMAKE_AWS_BATCH_ACCESS_KEY"),
+            region=os.environ.get("SNAKEMAKE_AWS_BATCH_REGION", "us-east-1"),
+            job_queue=os.environ.get("SNAKEMAKE_AWS_BATCH_JOB_QUEUE"),
+            execution_role=os.environ.get("SNAKEMAKE_AWS_BATCH_EXECUTION_ROLE"),
         )

     def get_assume_shared_fs(self) -> bool:
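Note (not part of this commit): running these tests against a real account presumably requires the variables above to be present in the environment; a hedged sketch with placeholder values (the access key pair can be omitted when the default credential chain is configured):

import os

os.environ.setdefault("SNAKEMAKE_AWS_BATCH_REGION", "us-east-1")
os.environ.setdefault(
    "SNAKEMAKE_AWS_BATCH_JOB_QUEUE",
    "arn:aws:batch:us-east-1:123456789012:job-queue/snakemake",
)
os.environ.setdefault(
    "SNAKEMAKE_AWS_BATCH_EXECUTION_ROLE",
    "arn:aws:iam::123456789012:role/snakemake-batch-execution",
)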

tests/tests_mocked_api.py

Lines changed: 16 additions & 0 deletions
@@ -5,5 +5,21 @@
 class TestWorkflowsMocked(TestWorkflowsBase):
     __test__ = True

+    @patch("boto3.client")
+    # TODO: patch run_job internals
+    @patch(
+        "snakemake.dag.DAG.check_and_touch_output",
+        new=AsyncMock(autospec=True),
+    )
+    @patch(
+        "snakemake_storage_plugin_s3.StorageObject.managed_size",
+        new=AsyncMock(autospec=True, return_value=0),
+    )
+    @patch(
+        # mocking has to happen in the importing module, see
+        # http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python
+        "snakemake.jobs.wait_for_files",
+        new=AsyncMock(autospec=True),
+    )
     def run_workflow(self, *args, **kwargs):
         super().run_workflow(*args, **kwargs)
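Note (not part of this commit): the decorators above rely on patch and AsyncMock being imported near the top of the module; that import sits outside this hunk, but presumably has this form:

from unittest.mock import AsyncMock, patch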
