Skip to content

Commit 8787186

Browse files
committed
rebase and fix flake
1 parent 907b537 commit 8787186

File tree

5 files changed

+93
-33
lines changed

5 files changed

+93
-33
lines changed

autoPyTorch/api/base_task.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,14 @@ def _get_dataset_input_validator(
299299
y_train: Union[List, pd.DataFrame, np.ndarray],
300300
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
301301
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
302303
resampling_strategy: Optional[ResamplingStrategies] = None,
303310
resampling_strategy_args: Optional[Dict[str, Any]] = None,
304311
dataset_name: Optional[str] = None,
305312
) -> Tuple[BaseDataset, BaseInputValidator]:
@@ -341,7 +348,14 @@ def get_dataset(
341348
y_train: Union[List, pd.DataFrame, np.ndarray],
342349
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
343350
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
344352
resampling_strategy: Optional[ResamplingStrategies] = None,
345359
resampling_strategy_args: Optional[Dict[str, Any]] = None,
346360
dataset_name: Optional[str] = None,
347361
) -> BaseDataset:
@@ -1389,7 +1403,14 @@ def fit_pipeline(
13891403
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
13901404
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
13911405
dataset_name: Optional[str] = None,
13921407
resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes]] = None,
13931414
resampling_strategy_args: Optional[Dict[str, Any]] = None,
13941415
run_time_limit_secs: int = 60,
13951416
memory_limit: Optional[int] = None,
@@ -1513,7 +1534,6 @@ def fit_pipeline(
15131534
(BaseDataset):
15141535
Dataset created from the given tensors
15151536
"""
1516-
self.dataset_name = dataset.dataset_name
15171537

15181538
if dataset is None:
15191539
if (

autoPyTorch/api/tabular_classification.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -374,19 +374,6 @@ def search(
374374
self
375375
376376
"""
377-
if dataset_name is None:
378-
dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
379-
380-
# we have to create a logger for at this point for the validator
381-
self._logger = self._get_logger(dataset_name)
382-
383-
# Create a validator object to make sure that the data provided by
384-
# the user matches the autopytorch requirements
385-
self.InputValidator = TabularInputValidator(
386-
is_classification=True,
387-
logger_port=self._logger_port,
388-
)
389-
390377
self.dataset, self.InputValidator = self._get_dataset_input_validator(
391378
X_train=X_train,
392379
y_train=y_train,
@@ -404,9 +391,6 @@ def search(
404391
'(CrossValTypes, HoldoutValTypes), but got {}'.format(self.resampling_strategy)
405392
)
406393

407-
if self.dataset is None:
408-
raise ValueError("`dataset` in {} must be initialized, but got None".format(self.__class__.__name__))
409-
410394
return self._search(
411395
dataset=self.dataset,
412396
optimize_metric=optimize_metric,

autoPyTorch/api/tabular_regression.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,9 +391,6 @@ def search(
391391
'(CrossValTypes, HoldoutValTypes), but got {}'.format(self.resampling_strategy)
392392
)
393393

394-
if self.dataset is None:
395-
raise ValueError("`dataset` in {} must be initialized, but got None".format(self.__class__.__name__))
396-
397394
return self._search(
398395
dataset=self.dataset,
399396
optimize_metric=optimize_metric,

autoPyTorch/evaluation/fit_evaluator.py

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
AbstractEvaluator,
1717
fit_and_suppress_warnings
1818
)
19+
from autoPyTorch.evaluation.utils import DisableFileOutputParameters
1920
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
2021
from autoPyTorch.utils.common import subsampler
2122
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
@@ -33,7 +34,7 @@ def __init__(self, backend: Backend, queue: Queue,
3334
num_run: Optional[int] = None,
3435
include: Optional[Dict[str, Any]] = None,
3536
exclude: Optional[Dict[str, Any]] = None,
36-
disable_file_output: Union[bool, List] = False,
37+
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
3738
init_params: Optional[Dict[str, Any]] = None,
3839
logger_port: Optional[int] = None,
3940
keep_models: Optional[bool] = None,
@@ -241,14 +242,11 @@ def file_output(
241242
)
242243

243244
# Abort if we don't want to output anything.
244-
if hasattr(self, 'disable_file_output'):
245-
if self.disable_file_output:
246-
return None, {}
247-
else:
248-
self.disabled_file_outputs = []
245+
if 'all' in self.disable_file_output:
246+
return None, {}
249247

250-
if hasattr(self, 'pipeline') and self.pipeline is not None:
251-
if 'pipeline' not in self.disabled_file_outputs:
248+
if getattr(self, 'pipeline', None) is not None:
249+
if 'pipeline' not in self.disable_file_output:
252250
pipeline = self.pipeline
253251
else:
254252
pipeline = None
@@ -265,11 +263,11 @@ def file_output(
265263
ensemble_predictions=None,
266264
valid_predictions=(
267265
Y_valid_pred if 'y_valid' not in
268-
self.disabled_file_outputs else None
266+
self.disable_file_output else None
269267
),
270268
test_predictions=(
271269
Y_test_pred if 'y_test' not in
272-
self.disabled_file_outputs else None
270+
self.disable_file_output else None
273271
),
274272
)
275273

@@ -287,8 +285,8 @@ def eval_function(
287285
num_run: int,
288286
include: Optional[Dict[str, Any]],
289287
exclude: Optional[Dict[str, Any]],
290-
disable_file_output: Union[bool, List],
291288
output_y_hat_optimization: bool = False,
289+
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
292290
pipeline_config: Optional[Dict[str, Any]] = None,
293291
budget_type: str = None,
294292
init_params: Optional[Dict[str, Any]] = None,
@@ -297,14 +295,75 @@ def eval_function(
297295
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
298296
instance: str = None,
299297
) -> None:
298+
"""
299+
This closure allows the communication between the ExecuteTaFuncWithQueue and the
300+
pipeline trainer (TrainEvaluator).
301+
302+
Fundamentally, smac calls the ExecuteTaFuncWithQueue.run() method, which internally
303+
builds a TrainEvaluator. The TrainEvaluator builds a pipeline, stores the output files
304+
to disc via the backend, and puts the performance result of the run in the queue.
305+
306+
307+
Attributes:
308+
backend (Backend):
309+
An object to interface with the disk storage. In particular, allows to
310+
access the train and test datasets
311+
queue (Queue):
312+
Each worker available will instantiate an evaluator, and after completion,
313+
it will return the evaluation result via a multiprocessing queue
314+
metric (autoPyTorchMetric):
315+
A scorer object that is able to evaluate how good a pipeline was fit. It
316+
is a wrapper on top of the actual score method (a wrapper on top of scikit
317+
learn accuracy for example) that formats the predictions accordingly.
318+
budget (float):
319+
The amount of epochs/time a configuration is allowed to run.
320+
budget_type (str):
321+
The budget type, which can be epochs or time
322+
pipeline_config (Optional[Dict[str, Any]]):
323+
Defines the content of the pipeline being evaluated. For example, it
324+
contains pipeline specific settings like logging name, or whether or not
325+
to use tensorboard.
326+
config (Union[int, str, Configuration]):
327+
Determines the pipeline to be constructed.
328+
seed (int):
329+
An integer that allows for reproducibility of results
330+
output_y_hat_optimization (bool):
331+
Whether this worker should output the target predictions, so that they are
332+
stored on disk. Fundamentally, the resampling strategy might shuffle the
333+
Y_train targets, so we store the split in order to re-use them for ensemble
334+
selection.
335+
num_run (Optional[int]):
336+
An identifier of the current configuration being fit. This number is unique per
337+
configuration.
338+
include (Optional[Dict[str, Any]]):
339+
An optional dictionary to include components of the pipeline steps.
340+
exclude (Optional[Dict[str, Any]]):
341+
An optional dictionary to exclude components of the pipeline steps.
342+
disable_file_output (Optional[List[Union[str, DisableFileOutputParameters]]]):
343+
By default, the model, its predictions, and other metadata are stored on disk
344+
for each finished configuration. This argument allows the user to skip
345+
saving certain file types, for example the model, from being written to disk.
346+
init_params (Optional[Dict[str, Any]]):
347+
Optional argument that is passed to each pipeline step. It is the equivalent of
348+
kwargs for the pipeline steps.
349+
logger_port (Optional[int]):
350+
Logging is performed using a socket-server scheme to be robust against many
351+
parallel entities that want to write to the same file. This integer states the
352+
socket port for the communication channel. If None is provided, a traditional
353+
logger is used.
354+
instance (str):
355+
An instance on which to evaluate the current pipeline. By default we work
356+
with a single instance, being the provided X_train, y_train of a single dataset.
357+
This instance is a compatibility argument for SMAC, that is capable of working
358+
with multiple datasets at the same time.
359+
"""
300360
evaluator = FitEvaluator(
301361
backend=backend,
302362
queue=queue,
303363
metric=metric,
304364
configuration=config,
305365
seed=seed,
306366
num_run=num_run,
307-
output_y_hat_optimization=output_y_hat_optimization,
308367
include=include,
309368
exclude=exclude,
310369
disable_file_output=disable_file_output,

autoPyTorch/evaluation/train_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,10 +428,10 @@ def eval_train_function(
428428
budget: float,
429429
config: Optional[Configuration],
430430
seed: int,
431-
output_y_hat_optimization: bool,
432431
num_run: int,
433432
include: Optional[Dict[str, Any]],
434433
exclude: Optional[Dict[str, Any]],
434+
output_y_hat_optimization: bool,
435435
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
436436
pipeline_config: Optional[Dict[str, Any]] = None,
437437
budget_type: str = None,

0 commit comments

Comments
 (0)