
Commit db70a43

[Feat] Better traditional pipeline cutoff time (#141)
* [Feat] Better traditional pipeline cutoff time
* Fix unit testing
* Better failure msg
* bug fix catboost
* Feedback from Ravin
* First batch of feedback from comments
* Missed examples
* Syntax fix
1 parent ef6acf2 commit db70a43

File tree

15 files changed: +277 −141 lines

autoPyTorch/api/base_task.py

Lines changed: 175 additions & 78 deletions
Large diffs are not rendered by default.

autoPyTorch/api/tabular_classification.py

Lines changed: 16 additions & 8 deletions
@@ -122,8 +122,8 @@ def search(
         budget_type: Optional[str] = None,
         budget: Optional[float] = None,
         total_walltime_limit: int = 100,
-        func_eval_time_limit: int = 60,
-        traditional_per_total_budget: float = 0.1,
+        func_eval_time_limit_secs: Optional[int] = None,
+        enable_traditional_pipeline: bool = True,
         memory_limit: Optional[int] = 4096,
         smac_scenario_args: Optional[Dict[str, Any]] = None,
         get_smac_object_callback: Optional[Callable] = None,
@@ -156,16 +156,24 @@ def search(
                 in seconds for the search of appropriate models.
                 By increasing this value, autopytorch has a higher
                 chance of finding better models.
-            func_eval_time_limit (int), (default=60): Time limit
+            func_eval_time_limit_secs (int), (default=None): Time limit
                 for a single call to the machine learning model.
                 Model fitting will be terminated if the machine
                 learning algorithm runs over the time limit. Set
                 this value high enough so that typical machine
                 learning algorithms can be fit on the training
                 data.
-            traditional_per_total_budget (float), (default=0.1):
-                Percent of total walltime to be allocated for
-                running traditional classifiers.
+                When set to None, this time will automatically be set to
+                total_walltime_limit // 2 to allow enough time to fit
+                at least 2 individual machine learning algorithms.
+                Set to np.inf in case no time limit is desired.
+            enable_traditional_pipeline (bool), (default=True):
+                We fit traditional machine learning algorithms
+                (LightGBM, CatBoost, RandomForest, ExtraTrees, KNN, SVM)
+                before building PyTorch Neural Networks. You can disable this
+                feature by turning this flag to False. All machine learning
+                algorithms that are fitted during search() are considered for
+                ensemble building.
             memory_limit (Optional[int]), (default=4096): Memory
                 limit in MB for the machine learning algorithm. autopytorch
                 will stop fitting the machine learning algorithm if it tries
@@ -228,8 +236,8 @@ def search(
             budget_type=budget_type,
             budget=budget,
             total_walltime_limit=total_walltime_limit,
-            func_eval_time_limit=func_eval_time_limit,
-            traditional_per_total_budget=traditional_per_total_budget,
+            func_eval_time_limit_secs=func_eval_time_limit_secs,
+            enable_traditional_pipeline=enable_traditional_pipeline,
             memory_limit=memory_limit,
             smac_scenario_args=smac_scenario_args,
             get_smac_object_callback=get_smac_object_callback,
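The docstring above fully specifies the new None default. A minimal sketch of that resolution logic, as an assumption (the corresponding base_task.py diff is not rendered on this page, so the function name and shape below are illustrative, not the committed code):

    from typing import Optional

    def resolve_func_eval_time_limit_secs(
        func_eval_time_limit_secs: Optional[int],
        total_walltime_limit: int,
    ) -> float:
        # Sketch only: mirrors the docstring, not the actual base_task.py code.
        if func_eval_time_limit_secs is None:
            # Half the total budget leaves enough room to fit
            # at least 2 individual machine learning algorithms.
            return total_walltime_limit // 2
        # np.inf may be passed when no per-evaluation limit is desired.
        return func_eval_time_limit_secs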

autoPyTorch/api/tabular_regression.py

Lines changed: 31 additions & 26 deletions
@@ -103,26 +103,27 @@ def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, An
     def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularRegressionPipeline:
         return TabularRegressionPipeline(dataset_properties=dataset_properties)

-    def search(self,
-               optimize_metric: str,
-               X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
-               y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
-               X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
-               y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
-               dataset_name: Optional[str] = None,
-               budget_type: Optional[str] = None,
-               budget: Optional[float] = None,
-               total_walltime_limit: int = 100,
-               func_eval_time_limit: int = 60,
-               traditional_per_total_budget: float = 0.1,
-               memory_limit: Optional[int] = 4096,
-               smac_scenario_args: Optional[Dict[str, Any]] = None,
-               get_smac_object_callback: Optional[Callable] = None,
-               all_supported_metrics: bool = True,
-               precision: int = 32,
-               disable_file_output: List = [],
-               load_models: bool = True,
-               ) -> 'BaseTask':
+    def search(
+        self,
+        optimize_metric: str,
+        X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        dataset_name: Optional[str] = None,
+        budget_type: Optional[str] = None,
+        budget: Optional[float] = None,
+        total_walltime_limit: int = 100,
+        func_eval_time_limit_secs: Optional[int] = None,
+        enable_traditional_pipeline: bool = False,
+        memory_limit: Optional[int] = 4096,
+        smac_scenario_args: Optional[Dict[str, Any]] = None,
+        get_smac_object_callback: Optional[Callable] = None,
+        all_supported_metrics: bool = True,
+        precision: int = 32,
+        disable_file_output: List = [],
+        load_models: bool = True,
+    ) -> 'BaseTask':
         """
         Search for the best pipeline configuration for the given dataset.
@@ -147,16 +148,20 @@ def search(self,
                 in seconds for the search of appropriate models.
                 By increasing this value, autopytorch has a higher
                 chance of finding better models.
-            func_eval_time_limit (int), (default=60): Time limit
+            func_eval_time_limit_secs (int), (default=None): Time limit
                 for a single call to the machine learning model.
                 Model fitting will be terminated if the machine
                 learning algorithm runs over the time limit. Set
                 this value high enough so that typical machine
                 learning algorithms can be fit on the training
                 data.
-            traditional_per_total_budget (float), (default=0.1):
-                Percent of total walltime to be allocated for
-                running traditional classifiers.
+                When set to None, this time will automatically be set to
+                total_walltime_limit // 2 to allow enough time to fit
+                at least 2 individual machine learning algorithms.
+                Set to np.inf in case no time limit is desired.
+            enable_traditional_pipeline (bool), (default=False):
+                Not enabled for regression. This flag is here to comply
+                with the API.
             memory_limit (Optional[int]), (default=4096): Memory
                 limit in MB for the machine learning algorithm. autopytorch
                 will stop fitting the machine learning algorithm if it tries
@@ -219,8 +224,8 @@ def search(self,
             budget_type=budget_type,
             budget=budget,
             total_walltime_limit=total_walltime_limit,
-            func_eval_time_limit=func_eval_time_limit,
-            traditional_per_total_budget=traditional_per_total_budget,
+            func_eval_time_limit_secs=func_eval_time_limit_secs,
+            enable_traditional_pipeline=enable_traditional_pipeline,
             memory_limit=memory_limit,
             smac_scenario_args=smac_scenario_args,
             get_smac_object_callback=get_smac_object_callback,

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 3 additions & 2 deletions
@@ -377,8 +377,9 @@ def _loss(self, y_true: np.ndarray, y_hat: np.ndarray) -> Dict[str, float]:

         """

-        if not isinstance(self.configuration, Configuration):
-            return {self.metric.name: self.metric._worst_possible_result}
+        if isinstance(self.configuration, int):
+            # We do not calculate performance of the dummy configurations
+            return {self.metric.name: self.metric._optimum - self.metric._sign * self.metric._worst_possible_result}

         if self.additional_metrics is not None:
             metrics = self.additional_metrics
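A worked check of the new expression, using the constants I would expect for the accuracy metric (treat them as assumptions; verify against autoPyTorch's metric definitions before relying on them):

    # Assumed accuracy-metric constants: optimal score 1.0, greater-is-better
    # sign +1, worst achievable score 0.0.
    optimum, sign, worst_possible_result = 1.0, 1.0, 0.0

    # The dummy configuration is pinned to the worst achievable loss:
    dummy_loss = optimum - sign * worst_possible_result
    print(dummy_loss)  # 1.0, i.e. the loss corresponding to 0% accuracy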

autoPyTorch/evaluation/tae.py

Lines changed: 2 additions & 2 deletions
@@ -442,7 +442,7 @@ def run(

         empty_queue(queue)
         self.logger.debug(
-            'Finished function evaluation. Status: %s, Cost: %f, Runtime: %f, Additional %s',
-            status, cost, runtime, additional_run_info,
+            'Finished function evaluation %s. Status: %s, Cost: %f, Runtime: %f, Additional %s',
+            str(num_run), status, cost, runtime, additional_run_info,
         )
         return status, cost, runtime, additional_run_info
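Purely illustrative, with made-up values, of what the reworked debug line renders to (the format string is the one from the diff above):

    msg = 'Finished function evaluation %s. Status: %s, Cost: %f, Runtime: %f, Additional %s' % (
        '5', 'SUCCESS', 1.0, 42.0, {},
    )
    # 'Finished function evaluation 5. Status: SUCCESS, Cost: 1.000000, Runtime: 42.000000, Additional {}'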

autoPyTorch/evaluation/train_evaluator.py

Lines changed: 8 additions & 1 deletion
@@ -110,6 +110,10 @@ def fit_predict_and_loss(self) -> None:

         status = StatusType.SUCCESS

+        self.logger.debug("In train evaluator fit_predict_and_loss, num_run: {} loss:{}".format(
+            self.num_run,
+            loss
+        ))
         self.finish_up(
             loss=loss,
             train_loss=train_loss,
@@ -236,7 +240,10 @@ def fit_predict_and_loss(self) -> None:
         self.pipeline = self._get_pipeline()

         status = StatusType.SUCCESS
-        self.logger.debug("In train evaluator fit_predict_and_loss, loss:{}".format(opt_loss))
+        self.logger.debug("In train evaluator fit_predict_and_loss, num_run: {} loss:{}".format(
+            self.num_run,
+            opt_loss
+        ))
         self.finish_up(
             loss=opt_loss,
             train_loss=train_loss,

autoPyTorch/optimizer/smbo.py

Lines changed: 4 additions & 4 deletions
@@ -84,7 +84,7 @@ def __init__(self,
                  dataset_name: str,
                  backend: Backend,
                  total_walltime_limit: float,
-                 func_eval_time_limit: float,
+                 func_eval_time_limit_secs: float,
                  memory_limit: typing.Optional[int],
                  metric: autoPyTorchMetric,
                  watcher: StopWatch,
@@ -120,7 +120,7 @@ def __init__(self,
                 An interface with disk
             total_walltime_limit (float):
                 The maximum allowed time for this job
-            func_eval_time_limit (float):
+            func_eval_time_limit_secs (float):
                 How much each individual task is allowed to last
             memory_limit (typing.Optional[int]):
                 Maximum allowed CPU memory this task can use
@@ -180,7 +180,7 @@ def __init__(self,
         # and a bunch of useful limits
         self.worst_possible_result = get_cost_of_crash(self.metric)
         self.total_walltime_limit = int(total_walltime_limit)
-        self.func_eval_time_limit = int(func_eval_time_limit)
+        self.func_eval_time_limit_secs = int(func_eval_time_limit_secs)
         self.memory_limit = memory_limit
         self.watcher = watcher
         self.seed = seed
@@ -265,7 +265,7 @@ def run_smbo(self, func: typing.Optional[typing.Callable] = None
         scenario_dict = {
             'abort_on_first_run_crash': False,
             'cs': self.config_space,
-            'cutoff_time': self.func_eval_time_limit,
+            'cutoff_time': self.func_eval_time_limit_secs,
             'deterministic': 'true',
             'instances': instances,
             'memory_limit': self.memory_limit,
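For context, a minimal sketch of how such a scenario dict is consumed, assuming the pre-1.0 SMAC3 Scenario API that this code targets; the config space and budget values are placeholders, not taken from the repository:

    from ConfigSpace import ConfigurationSpace
    from smac.scenario.scenario import Scenario

    config_space = ConfigurationSpace()  # placeholder; normally the pipeline's space
    scenario = Scenario({
        'run_obj': 'quality',   # optimize solution quality, not runtime
        'cs': config_space,
        'cutoff_time': 50,      # seconds allowed per function evaluation
        'wallclock_limit': 300, # total optimization budget in seconds
        'deterministic': 'true',
    })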

autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/__init__.py

Lines changed: 13 additions & 3 deletions
@@ -13,12 +13,22 @@
                           SVMModel)

 _classifiers = {
+    # Sort by the more robust models first.
+    # Depending on the allocated time budget, only the
+    # top models from this dict are to be fitted.
+    # LGBM is the most robust model, with
+    # internal measures to prevent crashes and overfitting.
+    # Additionally, it is one of the state-of-the-art
+    # methods for tabular prediction.
+    # Then follows CatBoost, for categorical-heavy
+    # datasets. The other models are complementary and
+    # their ordering is not critical.
+    'lgb': LGBModel,
     'catboost': CatboostModel,
+    'random_forest': RFModel,
     'extra_trees': ExtraTreesModel,
+    'svm_classifier': SVMModel,
     'knn_classifier': KNNModel,
-    'lgb': LGBModel,
-    'random_forest': RFModel,
-    'svm_classifier': SVMModel
 }
 _addons = ThirdPartyComponents(BaseClassifier)
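A hypothetical sketch of why insertion order matters here (pick_models and its per-model timing estimate are inventions for illustration, not repository code): under a tight budget only the first entries of _classifiers are fitted, and Python dicts preserve insertion order since 3.7, so the most robust models come first.

    import itertools

    def pick_models(classifiers: dict, budget_secs: float, secs_per_model: float) -> list:
        # Fit as many models as the budget covers, always starting from the
        # most robust entries at the front of the (insertion-ordered) dict.
        n = max(1, int(budget_secs // secs_per_model))
        return [name for name, _ in itertools.islice(classifiers.items(), n)]

    # e.g. pick_models(_classifiers, budget_secs=120, secs_per_model=50)
    # -> ['lgb', 'catboost']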

autoPyTorch/pipeline/traditional_tabular_classification.py

Lines changed: 7 additions & 2 deletions
@@ -211,8 +211,13 @@ def get_pipeline_representation(self) -> Dict[str, str]:
         """
         estimator_name = 'TraditionalTabularClassification'
         if self.steps[0][1].choice is not None:
-            estimator_name = cast(str,
-                                  self.steps[0][1].choice.model.get_properties()['shortname'])
+            if self.steps[0][1].choice.model is None:
+                estimator_name = self.steps[0][1].choice.__class__.__name__
+            else:
+                estimator_name = cast(
+                    str,
+                    self.steps[0][1].choice.model.get_properties()['shortname']
+                )
         return {
             'Preprocessing': 'None',
             'Estimator': estimator_name,
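The fallback means a pipeline whose model was never fitted now reports the choice's class name instead of failing on model.get_properties(). Illustrative only; the variable and shortname below are assumptions, not taken from the repository:

    # representation = pipeline.get_pipeline_representation()
    # fitted model    -> {'Preprocessing': 'None', 'Estimator': 'LGBMLearner'}
    # unfitted choice -> {'Preprocessing': 'None', 'Estimator': 'ModelChoice'}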

examples/tabular/20_basics/example_tabular_classification.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@
     y_test=y_test.copy(),
     optimize_metric='accuracy',
     total_walltime_limit=300,
-    func_eval_time_limit=50
+    func_eval_time_limit_secs=50
 )

 ############################################################################
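For completeness, a short sketch of the full call under the renamed API. The entry point and keyword names come from the diffs above; the surrounding data variables are assumed to be set up as in the example file:

    from autoPyTorch.api.tabular_classification import TabularClassificationTask

    api = TabularClassificationTask()
    api.search(
        X_train=X_train.copy(), y_train=y_train.copy(),
        X_test=X_test.copy(), y_test=y_test.copy(),
        optimize_metric='accuracy',
        total_walltime_limit=300,
        func_eval_time_limit_secs=50,      # renamed from func_eval_time_limit
        enable_traditional_pipeline=True,  # replaces traditional_per_total_budget
    )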
