16
16
AbstractEvaluator ,
17
17
fit_and_suppress_warnings
18
18
)
19
+ from autoPyTorch .evaluation .utils import DisableFileOutputParameters
19
20
from autoPyTorch .pipeline .components .training .metrics .base import autoPyTorchMetric
20
21
from autoPyTorch .utils .common import subsampler
21
22
from autoPyTorch .utils .hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
@@ -33,7 +34,7 @@ def __init__(self, backend: Backend, queue: Queue,
33
34
num_run : Optional [int ] = None ,
34
35
include : Optional [Dict [str , Any ]] = None ,
35
36
exclude : Optional [Dict [str , Any ]] = None ,
36
- disable_file_output : Union [bool , List ] = False ,
37
+ disable_file_output : Optional [ List [ Union [str , DisableFileOutputParameters ]]] = None ,
37
38
init_params : Optional [Dict [str , Any ]] = None ,
38
39
logger_port : Optional [int ] = None ,
39
40
keep_models : Optional [bool ] = None ,
@@ -241,14 +242,11 @@ def file_output(
241
242
)
242
243
243
244
# Abort if we don't want to output anything.
244
- if hasattr (self , 'disable_file_output' ):
245
- if self .disable_file_output :
246
- return None , {}
247
- else :
248
- self .disabled_file_outputs = []
245
+ if 'all' in self .disable_file_output :
246
+ return None , {}
249
247
250
- if hasattr (self , 'pipeline' ) and self . pipeline is not None :
251
- if 'pipeline' not in self .disabled_file_outputs :
248
+ if getattr (self , 'pipeline' , None ) is not None :
249
+ if 'pipeline' not in self .disable_file_output :
252
250
pipeline = self .pipeline
253
251
else :
254
252
pipeline = None
@@ -265,11 +263,11 @@ def file_output(
265
263
ensemble_predictions = None ,
266
264
valid_predictions = (
267
265
Y_valid_pred if 'y_valid' not in
268
- self .disabled_file_outputs else None
266
+ self .disable_file_output else None
269
267
),
270
268
test_predictions = (
271
269
Y_test_pred if 'y_test' not in
272
- self .disabled_file_outputs else None
270
+ self .disable_file_output else None
273
271
),
274
272
)
275
273
@@ -287,8 +285,8 @@ def eval_function(
287
285
num_run : int ,
288
286
include : Optional [Dict [str , Any ]],
289
287
exclude : Optional [Dict [str , Any ]],
290
- disable_file_output : Union [bool , List ],
291
288
output_y_hat_optimization : bool = False ,
289
+ disable_file_output : Optional [List [Union [str , DisableFileOutputParameters ]]] = None ,
292
290
pipeline_config : Optional [Dict [str , Any ]] = None ,
293
291
budget_type : str = None ,
294
292
init_params : Optional [Dict [str , Any ]] = None ,
@@ -297,14 +295,75 @@ def eval_function(
297
295
search_space_updates : Optional [HyperparameterSearchSpaceUpdates ] = None ,
298
296
instance : str = None ,
299
297
) -> None :
298
+ """
299
+ This closure allows the communication between the ExecuteTaFuncWithQueue and the
300
+ pipeline trainer (TrainEvaluator).
301
+
302
+ Fundamentally, smac calls the ExecuteTaFuncWithQueue.run() method, which internally
303
+ builds a TrainEvaluator. The TrainEvaluator builds a pipeline, stores the output files
304
+ to disc via the backend, and puts the performance result of the run in the queue.
305
+
306
+
307
+ Attributes:
308
+ backend (Backend):
309
+ An object to interface with the disk storage. In particular, allows to
310
+ access the train and test datasets
311
+ queue (Queue):
312
+ Each worker available will instantiate an evaluator, and after completion,
313
+ it will return the evaluation result via a multiprocessing queue
314
+ metric (autoPyTorchMetric):
315
+ A scorer object that is able to evaluate how good a pipeline was fit. It
316
+ is a wrapper on top of the actual score method (a wrapper on top of scikit
317
+ lean accuracy for example) that formats the predictions accordingly.
318
+ budget: (float):
319
+ The amount of epochs/time a configuration is allowed to run.
320
+ budget_type (str):
321
+ The budget type, which can be epochs or time
322
+ pipeline_config (Optional[Dict[str, Any]]):
323
+ Defines the content of the pipeline being evaluated. For example, it
324
+ contains pipeline specific settings like logging name, or whether or not
325
+ to use tensorboard.
326
+ config (Union[int, str, Configuration]):
327
+ Determines the pipeline to be constructed.
328
+ seed (int):
329
+ A integer that allows for reproducibility of results
330
+ output_y_hat_optimization (bool):
331
+ Whether this worker should output the target predictions, so that they are
332
+ stored on disk. Fundamentally, the resampling strategy might shuffle the
333
+ Y_train targets, so we store the split in order to re-use them for ensemble
334
+ selection.
335
+ num_run (Optional[int]):
336
+ An identifier of the current configuration being fit. This number is unique per
337
+ configuration.
338
+ include (Optional[Dict[str, Any]]):
339
+ An optional dictionary to include components of the pipeline steps.
340
+ exclude (Optional[Dict[str, Any]]):
341
+ An optional dictionary to exclude components of the pipeline steps.
342
+ disable_file_output (Union[bool, List[str]]):
343
+ By default, the model, it's predictions and other metadata is stored on disk
344
+ for each finished configuration. This argument allows the user to skip
345
+ saving certain file type, for example the model, from being written to disk.
346
+ init_params (Optional[Dict[str, Any]]):
347
+ Optional argument that is passed to each pipeline step. It is the equivalent of
348
+ kwargs for the pipeline steps.
349
+ logger_port (Optional[int]):
350
+ Logging is performed using a socket-server scheme to be robust against many
351
+ parallel entities that want to write to the same file. This integer states the
352
+ socket port for the communication channel. If None is provided, a traditional
353
+ logger is used.
354
+ instance (str):
355
+ An instance on which to evaluate the current pipeline. By default we work
356
+ with a single instance, being the provided X_train, y_train of a single dataset.
357
+ This instance is a compatibility argument for SMAC, that is capable of working
358
+ with multiple datasets at the same time.
359
+ """
300
360
evaluator = FitEvaluator (
301
361
backend = backend ,
302
362
queue = queue ,
303
363
metric = metric ,
304
364
configuration = config ,
305
365
seed = seed ,
306
366
num_run = num_run ,
307
- output_y_hat_optimization = output_y_hat_optimization ,
308
367
include = include ,
309
368
exclude = exclude ,
310
369
disable_file_output = disable_file_output ,
0 commit comments