
Commit b48b952

Merge pull request #89 from franchuterivera/InputValidator
Handling Input to auto pytorch
2 parents d33388b + a717d60, commit b48b952

33 files changed (+2746, -470 lines)

autoPyTorch/api/base_task.py

Lines changed: 45 additions & 38 deletions
@@ -9,6 +9,7 @@
 import time
 import typing
 import unittest.mock
+import uuid
 import warnings
 from abc import abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Union, cast
@@ -122,21 +123,24 @@ class BaseTask:
     """
 
     def __init__(
-        self,
-        seed: int = 1,
-        n_jobs: int = 1,
-        logging_config: Optional[Dict] = None,
-        ensemble_size: int = 50,
-        ensemble_nbest: int = 50,
-        max_models_on_disc: int = 50,
-        temporary_directory: Optional[str] = None,
-        output_directory: Optional[str] = None,
-        delete_tmp_folder_after_terminate: bool = True,
-        delete_output_folder_after_terminate: bool = True,
-        include_components: Optional[Dict] = None,
-        exclude_components: Optional[Dict] = None,
-        backend: Optional[Backend] = None,
-        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
+        self,
+        seed: int = 1,
+        n_jobs: int = 1,
+        logging_config: Optional[Dict] = None,
+        ensemble_size: int = 50,
+        ensemble_nbest: int = 50,
+        max_models_on_disc: int = 50,
+        temporary_directory: Optional[str] = None,
+        output_directory: Optional[str] = None,
+        delete_tmp_folder_after_terminate: bool = True,
+        delete_output_folder_after_terminate: bool = True,
+        include_components: Optional[Dict] = None,
+        exclude_components: Optional[Dict] = None,
+        backend: Optional[Backend] = None,
+        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+        resampling_strategy_args: Optional[Dict[str, Any]] = None,
+        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
+        task_type: Optional[str] = None
     ) -> None:
         self.seed = seed
         self.n_jobs = n_jobs
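
With this hunk the resampling configuration and the task type are fixed at construction time instead of being copied from the dataset later on. A minimal usage sketch, assuming the concrete TabularClassificationTask API class, its import paths, and the 'num_splits' key for resampling_strategy_args (none of which appear in this hunk):

# Sketch only: exercising the new constructor arguments through an assumed
# concrete subclass of BaseTask; class name, import paths and the 'num_splits'
# key are assumptions based on this diff, not guaranteed by it.
from autoPyTorch.api.tabular_classification import TabularClassificationTask
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes

# Default behaviour: a plain holdout split.
api = TabularClassificationTask(
    seed=1,
    resampling_strategy=HoldoutValTypes.holdout_validation,
)

# Cross-validation instead, with the number of folds passed through
# resampling_strategy_args ('num_splits' is an assumed key name).
api_cv = TabularClassificationTask(
    seed=1,
    resampling_strategy=CrossValTypes.k_fold_cross_validation,
    resampling_strategy_args={'num_splits': 5},
)
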
@@ -157,14 +161,14 @@ def __init__(
             delete_tmp_folder_after_terminate=delete_tmp_folder_after_terminate,
             delete_output_folder_after_terminate=delete_output_folder_after_terminate,
         )
+        self.task_type = task_type
         self._stopwatch = StopWatch()
 
         self.pipeline_options = replace_string_bool_to_bool(json.load(open(
             os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json'))))
 
         self.search_space: Optional[ConfigurationSpace] = None
         self._dataset_requirements: Optional[List[FitRequirement]] = None
-        self.task_type: Optional[str] = None
         self._metric: Optional[autoPyTorchMetric] = None
         self._logger: Optional[PicklableClientLogger] = None
         self.run_history: Optional[RunHistory] = None
@@ -176,7 +180,8 @@ def __init__(
         self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
 
         # Store the resampling strategy from the dataset, to load models as needed
-        self.resampling_strategy = None  # type: Optional[Union[CrossValTypes, HoldoutValTypes]]
+        self.resampling_strategy = resampling_strategy
+        self.resampling_strategy_args = resampling_strategy_args
 
         self.stop_logging_server = None  # type: Optional[multiprocessing.synchronize.Event]
 
@@ -287,7 +292,7 @@ def _get_logger(self, name: str) -> PicklableClientLogger:
             output_dir=self._backend.temporary_directory,
         )
 
-        # As Auto-sklearn works with distributed process,
+        # As AutoPyTorch works with distributed process,
         # we implement a logger server that can receive tcp
         # pickled messages. They are unpickled and processed locally
         # under the above logging configuration setting
@@ -398,20 +403,16 @@ def _close_dask_client(self) -> None:
             self._is_dask_client_internally_created = False
             del self._is_dask_client_internally_created
 
-    def _load_models(self, resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]]
-                     ) -> bool:
+    def _load_models(self) -> bool:
 
         """
         Loads the models saved in the temporary directory
         during the smac run and the final ensemble created
-        Args:
-            resampling_strategy (Union[CrossValTypes, HoldoutValTypes]): resampling strategy used to split the data
-                and to validate the performance of a candidate pipeline
 
         Returns:
             None
         """
-        if resampling_strategy is None:
+        if self.resampling_strategy is None:
             raise ValueError("Resampling strategy is needed to determine what models to load")
         self.ensemble_ = self._backend.load_ensemble(self.seed)
 
@@ -422,10 +423,10 @@ def _load_models(self, resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]]
         if self.ensemble_:
             identifiers = self.ensemble_.get_selected_model_identifiers()
             self.models_ = self._backend.load_models_by_identifiers(identifiers)
-            if isinstance(resampling_strategy, CrossValTypes):
+            if isinstance(self.resampling_strategy, CrossValTypes):
                 self.cv_models_ = self._backend.load_cv_models_by_identifiers(identifiers)
 
-            if isinstance(resampling_strategy, CrossValTypes):
+            if isinstance(self.resampling_strategy, CrossValTypes):
                 if len(self.cv_models_) == 0:
                     raise ValueError('No models fitted!')
 
@@ -610,10 +611,10 @@ def _do_traditional_prediction(self, num_run: int, time_for_traditional: int) ->
             )
         return num_run
 
-    def search(
+    def _search(
         self,
-        dataset: BaseDataset,
         optimize_metric: str,
+        dataset: BaseDataset,
         budget_type: Optional[str] = None,
         budget: Optional[float] = None,
         total_walltime_limit: int = 100,
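
Renaming search to _search and moving optimize_metric ahead of dataset fits the goal of PR #89: the concrete API classes are expected to expose a public search() that validates raw X/y input, builds the BaseDataset themselves and then delegates here. A rough caller-side sketch of that pattern; the public keyword names (X_train, y_train) and the wrapper itself are assumptions based on the PR title and the docstring note in the next hunk, not on code shown in this one:

# Hypothetical caller-side view after this change: the subclass's public
# search() is assumed to validate the raw input, wrap it in a dataset and
# forward to BaseTask._search(). Toy data, for illustration only.
import numpy as np
import pandas as pd

from autoPyTorch.api.tabular_classification import TabularClassificationTask  # assumed API class

X = pd.DataFrame({'age': [23, 45, 31, 52], 'income': [20000, 52000, 31000, 80000]})
y = np.array([0, 1, 0, 1])

api = TabularClassificationTask(seed=1)
api.search(
    X_train=X,
    y_train=y,
    optimize_metric='accuracy',   # in _search, the metric now precedes the dataset argument
    total_walltime_limit=300,
)
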
@@ -638,6 +639,7 @@ def search(
                 The argument that will provide the dataset splits. It is
                 a subclass of the base dataset object which can
                 generate the splits based on different restrictions.
+                Providing X_train, y_train and dataset together is not supported.
             optimize_metric (str): name of the metric that is used to
                 evaluate a pipeline.
             budget_type (Optional[str]):
@@ -692,6 +694,7 @@ def search(
             self
 
         """
+
         if self.task_type != dataset.task_type:
             raise ValueError("Incompatible dataset entered for current task,"
                              "expected dataset to have task type :{} got "
@@ -705,8 +708,8 @@ def search(
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
         self._stopwatch.start_task(experiment_task_name)
         self.dataset_name = dataset.dataset_name
-        self.resampling_strategy = dataset.resampling_strategy
-        self._logger = self._get_logger(self.dataset_name)
+        if self._logger is None:
+            self._logger = self._get_logger(self.dataset_name)
         self._all_supported_metrics = all_supported_metrics
         self._disable_file_output = disable_file_output
         self._memory_limit = memory_limit
@@ -869,7 +872,7 @@ def search(
 
         if load_models:
             self._logger.info("Loading models...")
-            self._load_models(dataset.resampling_strategy)
+            self._load_models()
             self._logger.info("Finished loading models...")
 
         # Clean up the logger
@@ -906,8 +909,11 @@ def refit(
         Returns:
             self
         """
+        if self.dataset_name is None:
+            self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
 
-        self._logger = self._get_logger(dataset.dataset_name)
+        if self._logger is None:
+            self._logger = self._get_logger(self.dataset_name)
 
         dataset_requirements = get_dataset_requirements(
             info=self._get_required_dataset_properties(dataset))
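
refit() (and fit() further down) no longer assume the dataset supplies a name for the logger; when self.dataset_name is unset they fall back to a UUID whose clock_seq folds in the process id, so concurrent runs on one machine get distinct logger names. A small, self-contained illustration of that fallback:

# Illustration of the fallback used when self.dataset_name is None.
# uuid1() embeds the current timestamp and node, and clock_seq carries
# (the low bits of) the process id, keeping names unique per process.
import os
import uuid

dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
print(dataset_name)  # e.g. '9f8d1a2c-...-...'; the value differs on every call
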
@@ -927,7 +933,7 @@ def refit(
             })
         X.update({**self.pipeline_options, **budget_config})
         if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None:
-            self._load_models(dataset.resampling_strategy)
+            self._load_models()
 
         # Refit is not applicable when ensemble_size is set to zero.
         if self.ensemble_ is None:
@@ -973,7 +979,11 @@ def fit(self,
         Returns:
             (BasePipeline): fitted pipeline
         """
-        self._logger = self._get_logger(dataset.dataset_name)
+        if self.dataset_name is None:
+            self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
+
+        if self._logger is None:
+            self._logger = self._get_logger(self.dataset_name)
 
         # get dataset properties
         dataset_requirements = get_dataset_requirements(
@@ -1025,7 +1035,7 @@ def predict(
         if self._logger is None:
             self._logger = self._get_logger("Predict-Logger")
 
-        if self.ensemble_ is None and not self._load_models(self.resampling_strategy):
+        if self.ensemble_ is None and not self._load_models():
             raise ValueError("No ensemble found. Either fit has not yet "
                              "been called or no ensemble was fitted")
 
@@ -1084,9 +1094,6 @@ def score(
         Returns:
             Dict[str, float]: Value of the evaluation metric calculated on the test set.
         """
-        if isinstance(y_test, pd.Series):
-            y_test = y_test.to_numpy(dtype=np.float)
-
         if self._metric is None:
             raise ValueError("No metric found. Either fit/search has not been called yet "
                              "or AutoPyTorch failed to infer a metric from the dataset ")

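score() drops the ad-hoc conversion of a pandas Series target, which also removes a use of the deprecated np.float alias; with the input validation this PR adds elsewhere, the target is presumably normalised before it reaches score(). For reference, a standalone equivalent of the removed lines using the non-deprecated dtype (where the conversion now happens is an assumption, not shown in this hunk):

# Standalone equivalent of the removed conversion, with np.float64 instead of
# the deprecated np.float alias. Under this PR the target is assumed to arrive
# already normalised, so score() no longer needs this step.
import numpy as np
import pandas as pd

y_test = pd.Series([0, 1, 1, 0])
if isinstance(y_test, pd.Series):
    y_test = y_test.to_numpy(dtype=np.float64)
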