 import uuid
 from pprint import pformat
 from urllib.parse import urlparse
-from typing import List, Generator, Optional
+from typing import AsyncGenerator, List, Optional
 
 from snakemake_interface_executor_plugins.executors.base import SubmittedJobInfo
 from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor
 from snakemake_interface_executor_plugins import ExecutorSettingsBase, CommonSettings
-from snakemake_interface_executor_plugins.workflow import WorkflowExecutorInterface
-from snakemake_interface_executor_plugins.logging import LoggerExecutorInterface
 from snakemake_interface_executor_plugins.jobs import (
     JobExecutorInterface,
 )
@@ -42,7 +40,6 @@
 
 from snakemake.remote.AzBlob import AzureStorageHelper
 import msrest.authentication as msa
-from snakemake_executor_plugin_azure_batch.common import bytesto
 
 
 # Optional:
@@ -91,29 +88,17 @@ class ExecutorSettings(ExecutorSettingsBase):
     # filesystem (True) or not (False).
     # This is e.g. the case for cloud execution.
     implies_no_shared_fs=True,
+    pass_default_storage_provider_args=True,
+    pass_default_resources_args=True,
+    pass_envvar_declarations_to_cmd=False,
+    auto_deploy_default_storage_provider=True,
 )
 
 
 # Required:
 # Implementation of your executor
 class Executor(RemoteExecutor):
-    def __init__(
-        self,
-        workflow: WorkflowExecutorInterface,
-        logger: LoggerExecutorInterface,
-    ):
-        super().__init__(
-            workflow,
-            logger,
-            # configure behavior of RemoteExecutor below
-            # whether arguments for setting the remote provider shall be passed to jobs
-            pass_default_storage_provider_args=True,
-            # whether arguments for setting default resources shall be passed to jobs
-            pass_default_resources_args=True,
-            # whether environment variables shall be passed to jobs
-            pass_envvar_declarations_to_cmd=True,
-        )
-
+    def __post_init__(self):
         AZURE_BATCH_RESOURCE_ENDPOINT = "https://batch.core.windows.net/"
 
         # Here we validate that az blob credential is SAS
@@ -124,14 +109,12 @@ def __init__(
         # TODO this does not work if the remote is used without default_remote_prefix
         # get container from remote prefix
         self.prefix_container = str.split(
-            workflow.storage_settings.default_remote_prefix, "/"
+            self.workflow.storage_settings.default_remote_prefix, "/"
         )[0]
 
         # setup batch configuration sets self.az_batch_config
         self.batch_config = AzBatchConfig(self.workflow.executor_settings.account_url)
-        logger.debug(f"AzBatchConfig: {self.mask_batch_config_as_string()}")
-
-        self.workflow = workflow
+        self.logger.debug(f"AzBatchConfig: {self.mask_batch_config_as_string()}")
 
         # handle case on OSX with /var/ symlinked to /private/var/ causing
         # issues with workdir not matching other workflow file dirs
@@ -156,28 +139,19 @@ def __init__(
         # enable autoscale flag
         self.az_batch_enable_autoscale = self.workflow.executor_settings.autoscale
 
-        # Package workflow sources files and upload to storage
-        self._build_packages = set()
-        targz = self._generate_build_source_package()
-
-        # removed after job failure/success
-        self.resource_file = self._upload_build_source_package(
-            targz, resource_prefix=self.batch_config.resource_file_prefix
-        )
-
         # authenticate batch client from SharedKeyCredentials
         if (
             self.batch_config.batch_account_key is not None
             and self.batch_config.managed_identity_client_id is None
         ):
-            logger.debug("Using batch account key for authentication...")
+            self.logger.debug("Using batch account key for authentication...")
             creds = SharedKeyCredentials(
                 self.batch_config.batch_account_name,
                 self.batch_config.batch_account_key,
             )
         # else authenticate with managed indentity client id
         elif self.batch_config.managed_identity_client_id is not None:
-            logger.debug("Using managed identity batch authentication...")
+            self.logger.debug("Using managed identity batch authentication...")
            creds = DefaultAzureCredential(
                 managed_identity_client_id=self.batch_config.managed_identity_client_id
             )
@@ -192,7 +166,7 @@ def __init__(
         if self.batch_config.managed_identity_resource_id is not None:
             self.batch_mgmt_client = BatchManagementClient(
                 credential=DefaultAzureCredential(
-                    managed_identity_client_id=self.batch_config.managed_identity_client_id  # noqa
+                    managed_identity_client_id=self.batch_config.managed_identity_client_id  # noqa
                 ),
                 subscription_id=self.batch_config.subscription_id,
             )
@@ -210,12 +184,6 @@ def shutdown(self):
         self.logger.debug("Deleting AzBatch pool")
         self.batch_client.pool.delete(self.pool_id)
 
-        self.logger.debug("Deleting workflow sources from blob")
-
-        self.azblob_helper.delete_from_container(
-            self.prefix_container, self.resource_file.file_path
-        )
-
         super().shutdown()
 
     def run_job(self, job: JobExecutorInterface):
@@ -238,10 +206,7 @@ def run_job(self, job: JobExecutorInterface):
                 continue
 
         exec_job = self.format_job_exec(job)
-        exec_job = (
-            f"/bin/bash -c 'tar xzf {self.resource_file.file_path} && "
-            f"{shlex.quote(exec_job)}'"
-        )
+        exec_job = f"/bin/bash -c '{shlex.quote(exec_job)}'"
 
         # A string that uniquely identifies the Task within the Job.
         task_uuid = str(uuid.uuid1())
@@ -267,7 +232,6 @@ def run_job(self, job: JobExecutorInterface):
             id=task_id,
             command_line=exec_job,
             container_settings=task_container_settings,
-            resource_files=[self.resource_file],  # Snakefile, envs, yml files etc.
             user_identity=batchmodels.UserIdentity(auto_user=user),
             environment_settings=envsettings,
         )
@@ -284,7 +248,7 @@ def run_job(self, job: JobExecutorInterface):
 
     async def check_active_jobs(
         self, active_jobs: List[SubmittedJobInfo]
-    ) -> Generator[SubmittedJobInfo, None, None]:
+    ) -> AsyncGenerator[SubmittedJobInfo, None, None]:
         # Check the status of active jobs.
 
         # You have to iterate over the given list active_jobs.
@@ -321,14 +285,6 @@ async def check_active_jobs(
                     )
                 )
 
-                def print_output():
-                    self.logger.debug(
-                        "task {}: stderr='{}'\n".format(batch_job.task_id, stderr)
-                    )
-                    self.logger.debug(
-                        "task {}: stdout='{}'\n".format(batch_job.task_id, stdout)
-                    )
-
                 if (
                     task.execution_info.result
                     == batchmodels.TaskExecutionResult.failure
@@ -645,94 +601,6 @@ def validate_az_blob_credential_is_sas():
                     "AZ_BLOB_CREDENTIAL is not a valid storage account SAS token."
                 )
 
-    # from google_lifesciences.py
-    def _set_workflow_sources(self):
-        """We only add files from the working directory that are config related
-        (e.g., the Snakefile or a config.yml equivalent), or checked into git.
-        """
-        self.workflow_sources = []
-
-        for wfs in self.dag.get_sources():
-            if os.path.isdir(wfs):
-                for dirpath, dirnames, filenames in os.walk(wfs):
-                    self.workflow_sources.extend(
-                        [
-                            self._check_source_size(os.path.join(dirpath, f))
-                            for f in filenames
-                        ]
-                    )
-            else:
-                self.workflow_sources.append(
-                    self._check_source_size(os.path.abspath(wfs))
-                )
-
-    # from google_lifesciences.py
-    def _generate_build_source_package(self):
-        """in order for the instance to access the working directory in storage,
-        we need to upload it. This file is cleaned up at the end of the run.
-        We do this, and then obtain from the instance and extract.
-        """
-        # Workflow sources for cloud executor must all be under same workdir root
-        for filename in self.workflow_sources:
-            if self.workdir not in filename:
-                raise WorkflowError(
-                    "All source files must be present in the working directory, "
-                    "{workdir} to be uploaded to a build package that respects "
-                    "relative paths, but {filename} was found outside of this "
-                    "directory. Please set your working directory accordingly, "
-                    "and the path of your Snakefile to be relative to it.".format(
-                        workdir=self.workdir, filename=filename
-                    )
-                )
-
-        # We will generate a tar.gz package, renamed by hash
-        tmpname = next(tempfile._get_candidate_names())
-        targz = os.path.join(tempfile.gettempdir(), f"snakemake-{tmpname}.tar.gz")
-        tar = tarfile.open(targz, "w:gz")
-
-        # Add all workflow_sources files
-        for filename in self.workflow_sources:
-            arcname = filename.replace(self.workdir + os.path.sep, "")
-            tar.add(filename, arcname=arcname)
-        self.logger.debug(
-            f"Created {targz} with the following contents: {self.workflow_sources}"
-        )
-        tar.close()
-
-        # Rename based on hash, in case user wants to save cache
-        hasher = hashlib.sha256()
-        hasher.update(open(targz, "rb").read())
-        sha256 = hasher.hexdigest()
-
-        hash_tar = os.path.join(
-            self.workflow.persistence.aux_path, f"workdir-{sha256}.tar.gz"
-        )
-
-        # Only copy if we don't have it yet, clean up if we do
-        if not os.path.exists(hash_tar):
-            shutil.move(targz, hash_tar)
-        else:
-            os.remove(targz)
-
-        # We will clean these all up at shutdown
-        self._build_packages.add(hash_tar)
-
-        return hash_tar
-
-    def _upload_build_source_package(self, targz, resource_prefix=""):
-        """given a .tar.gz created for a workflow, upload it to the blob
-        storage account, only if the blob doesn't already exist.
-        """
-        blob_name = os.path.join(resource_prefix, os.path.basename(targz))
-
-        # upload blob to storage using storage helper
-        bc = self.azblob_helper.upload_to_azure_storage(
-            self.prefix_container, targz, blob_name=blob_name
-        )
-
-        # return resource file
-        return batchmodels.ResourceFile(http_url=bc.url, file_path=blob_name)
-
     # from https://github.com/Azure-Samples/batch-python-quickstart/blob/master/src/python_quickstart_client.py # noqa
     @staticmethod
     def _read_stream_as_string(stream, encoding):
@@ -765,19 +633,6 @@ def _get_task_output(self, job_id, task_id, stdout_or_stderr, encoding=None):
 
         return content
 
-    def _check_source_size(self, filename, warning_size_gb=0.2):
-        """A helper function to check the filesize, and return the file
-        to the calling function Additionally, given that we encourage these
-        packages to be small, we set a warning at 200MB (0.2GB).
-        """
-        gb = bytesto(os.stat(filename).st_size, "g")
-        if gb > warning_size_gb:
-            self.logger.warning(
-                f"File {filename} (size {gb} GB) is greater than the {warning_size_gb}"
-                "GB suggested size. Consider uploading larger files to storage first."
-            )
-        return filename
-
 
 class AzBatchConfig:
     def __init__(self, batch_account_url: str):