@@ -214,22 +214,11 @@ def __init__(
214214 mlflow_tracking_uri : str | None = None ,
215215 ** kwargs : Any ,
216216 ):
217- logger .info (f"AutoRunner using work directory { work_dir } " )
218- os .makedirs (work_dir , exist_ok = True )
219-
220- self .work_dir = os .path .abspath (work_dir )
221- self .data_src_cfg = dict ()
222- self .data_src_cfg_name = os .path .join (self .work_dir , "input.yaml" )
223- self .algos = algos
224- self .templates_path_or_url = templates_path_or_url
225- self .allow_skip = allow_skip
226- self .mlflow_tracking_uri = mlflow_tracking_uri
227- self .kwargs = deepcopy (kwargs )
228-
229- if input is None and os .path .isfile (self .data_src_cfg_name ):
230- input = self .data_src_cfg_name
217+ if input is None and os .path .isfile (os .path .join (os .path .abspath (work_dir ), "input.yaml" )):
218+ input = os .path .join (os .path .abspath (work_dir ), "input.yaml" )
231219 logger .info (f"Input config is not provided, using the default { input } " )
232220
221+ self .data_src_cfg = dict ()
233222 if isinstance (input , dict ):
234223 self .data_src_cfg = input
235224 elif isinstance (input , str ) and os .path .isfile (input ):
@@ -238,6 +227,51 @@ def __init__(
238227 else :
239228 raise ValueError (f"{ input } is not a valid file or dict" )
240229
230+ if "work_dir" in self .data_src_cfg : # override from config
231+ work_dir = self .data_src_cfg ["work_dir" ]
232+ self .work_dir = os .path .abspath (work_dir )
233+
234+ logger .info (f"AutoRunner using work directory { self .work_dir } " )
235+ os .makedirs (self .work_dir , exist_ok = True )
236+ self .data_src_cfg_name = os .path .join (self .work_dir , "input.yaml" )
237+
238+ self .algos = algos
239+ self .templates_path_or_url = templates_path_or_url
240+ self .allow_skip = allow_skip
241+
242+ # cache.yaml
243+ self .not_use_cache = not_use_cache
244+ self .cache_filename = os .path .join (self .work_dir , "cache.yaml" )
245+ self .cache = self .read_cache ()
246+ self .export_cache ()
247+
248+ # determine if we need to analyze, algo_gen or train from cache, unless manually provided
249+ self .analyze = not self .cache ["analyze" ] if analyze is None else analyze
250+ self .algo_gen = not self .cache ["algo_gen" ] if algo_gen is None else algo_gen
251+ self .train = train
252+ self .ensemble = ensemble # last step, no need to check
253+ self .hpo = hpo and has_nni
254+ self .hpo_backend = hpo_backend
255+ self .mlflow_tracking_uri = mlflow_tracking_uri
256+ self .kwargs = deepcopy (kwargs )
257+
258+ # parse input config for AutoRunner param overrides
259+ for param in [
260+ "analyze" ,
261+ "algo_gen" ,
262+ "train" ,
263+ "hpo" ,
264+ "ensemble" ,
265+ "not_use_cache" ,
266+ "allow_skip" ,
267+ ]: # override from config
268+ if param in self .data_src_cfg and isinstance (self .data_src_cfg [param ], bool ):
269+ setattr (self , param , self .data_src_cfg [param ]) # e.g. self.analyze = self.data_src_cfg["analyze"]
270+
271+ for param in ["algos" , "hpo_backend" , "templates_path_or_url" , "mlflow_tracking_uri" ]: # override from config
272+ if param in self .data_src_cfg :
273+ setattr (self , param , self .data_src_cfg [param ]) # e.g. self.algos = self.data_src_cfg["algos"]
274+
241275 missing_keys = {"dataroot" , "datalist" , "modality" }.difference (self .data_src_cfg .keys ())
242276 if len (missing_keys ) > 0 :
243277 raise ValueError (f"Config keys are missing { missing_keys } " )
@@ -256,6 +290,8 @@ def __init__(
256290
257291 # inspect and update folds
258292 num_fold = self .inspect_datalist_folds (datalist_filename = datalist_filename )
293+ if "num_fold" in self .data_src_cfg :
294+ num_fold = int (self .data_src_cfg ["num_fold" ]) # override from config
259295
260296 self .data_src_cfg ["datalist" ] = datalist_filename # update path to a version in work_dir and save user input
261297 ConfigParser .export_config_file (
@@ -266,17 +302,6 @@ def __init__(
266302 self .datastats_filename = os .path .join (self .work_dir , "datastats.yaml" )
267303 self .datalist_filename = datalist_filename
268304
269- self .not_use_cache = not_use_cache
270- self .cache_filename = os .path .join (self .work_dir , "cache.yaml" )
271- self .cache = self .read_cache ()
272- self .export_cache ()
273-
274- # determine if we need to analyze, algo_gen or train from cache, unless manually provided
275- self .analyze = not self .cache ["analyze" ] if analyze is None else analyze
276- self .algo_gen = not self .cache ["algo_gen" ] if algo_gen is None else algo_gen
277- self .train = train
278- self .ensemble = ensemble # last step, no need to check
279-
280305 self .set_training_params ()
281306 self .set_device_info ()
282307 self .set_prediction_params ()
@@ -288,9 +313,9 @@ def __init__(
288313 self .gpu_customization_specs : dict [str , Any ] = {}
289314
290315 # hpo
291- if hpo_backend .lower () != "nni" :
316+ if self . hpo_backend .lower () != "nni" :
292317 raise NotImplementedError ("HPOGen backend only supports NNI" )
293- self .hpo = hpo and has_nni
318+ self .hpo = self . hpo and has_nni
294319 self .set_hpo_params ()
295320 self .search_space : dict [str , dict [str , Any ]] = {}
296321 self .hpo_tasks = 0
0 commit comments