 import time
 import typing
 import unittest.mock
+import uuid
 import warnings
 from abc import abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Union, cast
@@ -122,21 +123,24 @@ class BaseTask:
     """

     def __init__(
-        self,
-        seed: int = 1,
-        n_jobs: int = 1,
-        logging_config: Optional[Dict] = None,
-        ensemble_size: int = 50,
-        ensemble_nbest: int = 50,
-        max_models_on_disc: int = 50,
-        temporary_directory: Optional[str] = None,
-        output_directory: Optional[str] = None,
-        delete_tmp_folder_after_terminate: bool = True,
-        delete_output_folder_after_terminate: bool = True,
-        include_components: Optional[Dict] = None,
-        exclude_components: Optional[Dict] = None,
-        backend: Optional[Backend] = None,
-        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
+        self,
+        seed: int = 1,
+        n_jobs: int = 1,
+        logging_config: Optional[Dict] = None,
+        ensemble_size: int = 50,
+        ensemble_nbest: int = 50,
+        max_models_on_disc: int = 50,
+        temporary_directory: Optional[str] = None,
+        output_directory: Optional[str] = None,
+        delete_tmp_folder_after_terminate: bool = True,
+        delete_output_folder_after_terminate: bool = True,
+        include_components: Optional[Dict] = None,
+        exclude_components: Optional[Dict] = None,
+        backend: Optional[Backend] = None,
+        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+        resampling_strategy_args: Optional[Dict[str, Any]] = None,
+        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
+        task_type: Optional[str] = None
     ) -> None:
         self.seed = seed
         self.n_jobs = n_jobs
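
With this change, the resampling configuration travels with the API object instead of being read back from the dataset at search time. Below is a minimal sketch of what construction might look like under the new signature; the `TabularClassificationTask` class, the enum member, and the `'num_splits'` key are assumptions drawn from the surrounding project, not from this diff:

```python
# Hedged sketch: the concrete class name, enum member and resampling_strategy_args
# keys are assumptions and may differ in this revision.
from autoPyTorch.api.tabular_classification import TabularClassificationTask
from autoPyTorch.datasets.resampling_strategy import CrossValTypes

api = TabularClassificationTask(
    seed=42,
    ensemble_size=50,
    resampling_strategy=CrossValTypes.k_fold_cross_validation,
    resampling_strategy_args={'num_splits': 5},
)
```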
@@ -157,14 +161,14 @@ def __init__(
             delete_tmp_folder_after_terminate=delete_tmp_folder_after_terminate,
             delete_output_folder_after_terminate=delete_output_folder_after_terminate,
         )
+        self.task_type = task_type
         self._stopwatch = StopWatch()

         self.pipeline_options = replace_string_bool_to_bool(json.load(open(
             os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json'))))

         self.search_space: Optional[ConfigurationSpace] = None
         self._dataset_requirements: Optional[List[FitRequirement]] = None
-        self.task_type: Optional[str] = None
         self._metric: Optional[autoPyTorchMetric] = None
         self._logger: Optional[PicklableClientLogger] = None
         self.run_history: Optional[RunHistory] = None
@@ -176,7 +180,8 @@ def __init__(
         self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT

         # Store the resampling strategy from the dataset, to load models as needed
-        self.resampling_strategy = None  # type: Optional[Union[CrossValTypes, HoldoutValTypes]]
+        self.resampling_strategy = resampling_strategy
+        self.resampling_strategy_args = resampling_strategy_args

         self.stop_logging_server = None  # type: Optional[multiprocessing.synchronize.Event]

@@ -287,7 +292,7 @@ def _get_logger(self, name: str) -> PicklableClientLogger:
             output_dir=self._backend.temporary_directory,
         )

-        # As Auto-sklearn works with distributed process,
+        # As AutoPyTorch works with distributed process,
         # we implement a logger server that can receive tcp
         # pickled messages. They are unpickled and processed locally
         # under the above logging configuration setting
@@ -398,20 +403,16 @@ def _close_dask_client(self) -> None:
             self._is_dask_client_internally_created = False
             del self._is_dask_client_internally_created

-    def _load_models(self, resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]]
-                     ) -> bool:
+    def _load_models(self) -> bool:

         """
         Loads the models saved in the temporary directory
         during the smac run and the final ensemble created
-        Args:
-            resampling_strategy (Union[CrossValTypes, HoldoutValTypes]): resampling strategy used to split the data
-                and to validate the performance of a candidate pipeline

         Returns:
             None
         """
-        if resampling_strategy is None:
+        if self.resampling_strategy is None:
             raise ValueError("Resampling strategy is needed to determine what models to load")
         self.ensemble_ = self._backend.load_ensemble(self.seed)

@@ -422,10 +423,10 @@ def _load_models(self, resampling_strategy: Optional[Union[CrossValTypes, Holdou
         if self.ensemble_:
             identifiers = self.ensemble_.get_selected_model_identifiers()
             self.models_ = self._backend.load_models_by_identifiers(identifiers)
-            if isinstance(resampling_strategy, CrossValTypes):
+            if isinstance(self.resampling_strategy, CrossValTypes):
                 self.cv_models_ = self._backend.load_cv_models_by_identifiers(identifiers)

-            if isinstance(resampling_strategy, CrossValTypes):
+            if isinstance(self.resampling_strategy, CrossValTypes):
                 if len(self.cv_models_) == 0:
                     raise ValueError('No models fitted!')

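
Because `_load_models` now reads `self.resampling_strategy`, the strategy fixed at construction time alone decides whether the single fitted models or the per-fold cross-validation models back later use. A rough sketch of that selection, assuming the `models_`/`cv_models_` split shown above is what prediction consults (not the literal implementation):

```python
# Hedged sketch of the selection implied by the hunk above.
from autoPyTorch.datasets.resampling_strategy import CrossValTypes


def models_for_prediction(task):
    # Cross-validation strategies populate cv_models_ (one model per fold);
    # holdout strategies only populate models_.
    if isinstance(task.resampling_strategy, CrossValTypes):
        return task.cv_models_
    return task.models_
```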
@@ -610,10 +611,10 @@ def _do_traditional_prediction(self, num_run: int, time_for_traditional: int) ->
         )
         return num_run

-    def search(
+    def _search(
         self,
-        dataset: BaseDataset,
         optimize_metric: str,
+        dataset: BaseDataset,
         budget_type: Optional[str] = None,
         budget: Optional[float] = None,
         total_walltime_limit: int = 100,
@@ -638,6 +639,7 @@ def search(
                 The argument that will provide the dataset splits. It is
                 a subclass of the base dataset object which can
                 generate the splits based on different restrictions.
+                Providing X_train, y_train and dataset together is not supported.
             optimize_metric (str): name of the metric that is used to
                 evaluate a pipeline.
             budget_type (Optional[str]):
@@ -692,6 +694,7 @@ def search(
             self

         """
+
         if self.task_type != dataset.task_type:
             raise ValueError("Incompatible dataset entered for current task,"
                              "expected dataset to have task type :{} got "
@@ -705,8 +708,8 @@ def search(
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
         self._stopwatch.start_task(experiment_task_name)
         self.dataset_name = dataset.dataset_name
-        self.resampling_strategy = dataset.resampling_strategy
-        self._logger = self._get_logger(self.dataset_name)
+        if self._logger is None:
+            self._logger = self._get_logger(self.dataset_name)
         self._all_supported_metrics = all_supported_metrics
         self._disable_file_output = disable_file_output
         self._memory_limit = memory_limit
@@ -869,7 +872,7 @@ def search(

         if load_models:
             self._logger.info("Loading models...")
-            self._load_models(dataset.resampling_strategy)
+            self._load_models()
             self._logger.info("Finished loading models...")

         # Clean up the logger
@@ -906,8 +909,11 @@ def refit(
         Returns:
             self
         """
+        if self.dataset_name is None:
+            self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

-        self._logger = self._get_logger(dataset.dataset_name)
+        if self._logger is None:
+            self._logger = self._get_logger(self.dataset_name)

         dataset_requirements = get_dataset_requirements(
             info=self._get_required_dataset_properties(dataset))
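
`refit` (and `fit` further down) no longer require a dataset-derived name or a pre-existing logger: a fresh object falls back to a generated dataset name. The snippet below only illustrates why `uuid.uuid1(clock_seq=os.getpid())` is a reasonable fallback: the timestamp keeps names distinct within a process, and seeding the clock sequence with the PID makes collisions across worker processes on the same host unlikely.

```python
# Illustration only: uuid1 combines host, timestamp and the supplied clock
# sequence, so repeated calls (and parallel workers) get distinct names.
import os
import uuid

name_a = str(uuid.uuid1(clock_seq=os.getpid()))
name_b = str(uuid.uuid1(clock_seq=os.getpid()))
assert name_a != name_b  # the timestamp advances between calls in one process
```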
@@ -927,7 +933,7 @@ def refit(
         })
         X.update({**self.pipeline_options, **budget_config})
         if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None:
-            self._load_models(dataset.resampling_strategy)
+            self._load_models()

         # Refit is not applicable when ensemble_size is set to zero.
         if self.ensemble_ is None:
@@ -973,7 +979,11 @@ def fit(self,
         Returns:
             (BasePipeline): fitted pipeline
         """
-        self._logger = self._get_logger(dataset.dataset_name)
+        if self.dataset_name is None:
+            self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
+
+        if self._logger is None:
+            self._logger = self._get_logger(self.dataset_name)

         # get dataset properties
         dataset_requirements = get_dataset_requirements(
@@ -1025,7 +1035,7 @@ def predict(
         if self._logger is None:
             self._logger = self._get_logger("Predict-Logger")

-        if self.ensemble_ is None and not self._load_models(self.resampling_strategy):
+        if self.ensemble_ is None and not self._load_models():
             raise ValueError("No ensemble found. Either fit has not yet "
                              "been called or no ensemble was fitted")

@@ -1084,9 +1094,6 @@ def score(
         Returns:
             Dict[str, float]: Value of the evaluation metric calculated on the test set.
         """
-        if isinstance(y_test, pd.Series):
-            y_test = y_test.to_numpy(dtype=np.float)
-
         if self._metric is None:
             raise ValueError("No metric found. Either fit/search has not been called yet "
                              "or AutoPyTorch failed to infer a metric from the dataset ")