[ADD] Robustly refit models in final ensemble in parallel (#471)
* add parallel model runner and update running traditional classifiers

* update pipeline config to pipeline options

* working refit function

* fix mypy and flake

* suggestions from review

* fix mypy and flake

* suggestions from review

* finish documentation

* fix tests

* add test for parallel model runner

* fix flake

* fix tests

* fix traditional prediction for refit

* suggestions from review

* add warning for failed processing of results

* remove unnecessary change

* update autopytorch version number

* update autopytorch version number and the example file
ravinkohli authored Aug 23, 2022
1 parent d160903 commit ce78f89
Showing 23 changed files with 909 additions and 276 deletions.
2 changes: 1 addition & 1 deletion autoPyTorch/__version__.py
@@ -1,4 +1,4 @@
"""Version information."""

# The following line *must* be the last in the module, exactly as formatted:
__version__ = "0.2"
__version__ = "0.2.1"
419 changes: 262 additions & 157 deletions autoPyTorch/api/base_task.py

Large diffs are not rendered by default.
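autoPyTorch/api/base_task.py is the largest part of this change, but its diff is not rendered here; the commit message above mentions a parallel model runner and a refit function for the final ensemble. Purely as a generic illustration of that idea, and not autoPyTorch's actual API, refitting several ensemble members across worker processes could look like the sketch below (refit_one is a placeholder for whatever retrains a single member and returns its new identifier):

from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Callable, Dict, List, Tuple

Identifier = Tuple[int, int, float]  # (seed, num_run, budget): the ensemble refers to members by such tuples

def refit_in_parallel(identifiers: List[Identifier],
                      refit_one: Callable[[Identifier], Identifier],
                      n_jobs: int = 4) -> Dict[Identifier, Identifier]:
    # Run one refit per ensemble member in a worker process and collect a
    # mapping from the old identifier to the identifier of the refitted run.
    replace_mapping: Dict[Identifier, Identifier] = {}
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        futures = {pool.submit(refit_one, identifier): identifier for identifier in identifiers}
        for future in as_completed(futures):
            old_identifier = futures[future]
            try:
                replace_mapping[old_identifier] = future.result()
            except Exception:
                # Keep the original model if a refit fails; the commit above adds
                # a warning for failed processing of results in the same spirit.
                replace_mapping[old_identifier] = old_identifier
    return replace_mapping

A mapping of this shape is what the update_identifiers method added below in abstract_ensemble.py consumes.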

12 changes: 12 additions & 0 deletions autoPyTorch/ensemble/abstract_ensemble.py
@@ -9,6 +9,9 @@
class AbstractEnsemble(object):
    __metaclass__ = ABCMeta

    def __init__(self):
        self.identifiers_: List[Tuple[int, int, float]] = []

    @abstractmethod
    def fit(
        self,
@@ -76,3 +79,12 @@ def get_validation_performance(self) -> float:
        Returns:
            Score
        """

    def update_identifiers(
        self,
        replace_identifiers_mapping: Dict[Tuple[int, int, float], Tuple[int, int, float]]
    ) -> None:
        identifiers = self.identifiers_.copy()
        for i, identifier in enumerate(self.identifiers_):
            identifiers[i] = replace_identifiers_mapping.get(identifier, identifier)
        self.identifiers_ = identifiers
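
The update_identifiers method added above swaps ensemble member identifiers according to a replacement mapping, presumably so that a refitted ensemble can be pointed at the newly trained runs. A minimal, self-contained sketch of the same replacement logic, using plain tuples rather than the autoPyTorch class and made-up identifier values:

from typing import Dict, List, Tuple

def update_identifiers(
    identifiers: List[Tuple[int, int, float]],
    replace_identifiers_mapping: Dict[Tuple[int, int, float], Tuple[int, int, float]],
) -> List[Tuple[int, int, float]]:
    # Swap every identifier found in the mapping for its refitted counterpart;
    # identifiers without a replacement entry are kept unchanged.
    return [replace_identifiers_mapping.get(identifier, identifier) for identifier in identifiers]

# Hypothetical example: two members are replaced by refitted runs, the third stays as-is.
old_ids = [(1, 4, 50.0), (1, 7, 50.0), (1, 9, 50.0)]
mapping = {(1, 4, 50.0): (1, 12, 50.0), (1, 7, 50.0): (1, 13, 50.0)}
print(update_identifiers(old_ids, mapping))  # [(1, 12, 50.0), (1, 13, 50.0), (1, 9, 50.0)]
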
28 changes: 17 additions & 11 deletions autoPyTorch/evaluation/abstract_evaluator.py
@@ -195,7 +195,8 @@ def get_additional_run_info(self) -> Dict[str, Any]:
Can be found in autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs
"""
return {'pipeline_configuration': self.configuration,
'trainer_configuration': self.pipeline.named_steps['model_trainer'].choice.model.get_config()}
'trainer_configuration': self.pipeline.named_steps['model_trainer'].choice.model.get_config(),
'configuration_origin': 'traditional'}

def get_pipeline_representation(self) -> Dict[str, str]:
return self.pipeline.get_pipeline_representation()
@@ -347,7 +348,7 @@ class AbstractEvaluator(object):
An evaluator is an object that:
+ constructs a pipeline (i.e. a classification or regression estimator) for a given
pipeline_config and run settings (budget, seed)
pipeline_options and run settings (budget, seed)
+ Fits and trains this pipeline (TrainEvaluator) or tests a given
configuration (TestEvaluator)
@@ -369,7 +370,7 @@ class AbstractEvaluator(object):
The amount of epochs/time a configuration is allowed to run.
budget_type (str):
The budget type. Currently, only epoch and time are allowed.
pipeline_config (Optional[Dict[str, Any]]):
pipeline_options (Optional[Dict[str, Any]]):
Defines the content of the pipeline being evaluated. For example, it
contains pipeline specific settings like logging name, or whether or not
to use tensorboard.
@@ -430,7 +431,7 @@ def __init__(self, backend: Backend,
budget: float,
configuration: Union[int, str, Configuration],
budget_type: str = None,
pipeline_config: Optional[Dict[str, Any]] = None,
pipeline_options: Optional[Dict[str, Any]] = None,
seed: int = 1,
output_y_hat_optimization: bool = True,
num_run: Optional[int] = None,
@@ -523,10 +524,10 @@ def __init__(self, backend: Backend,
self._init_params = init_params

assert self.pipeline_class is not None, "Could not infer pipeline class"
pipeline_config = pipeline_config if pipeline_config is not None \
pipeline_options = pipeline_options if pipeline_options is not None \
else self.pipeline_class.get_default_pipeline_options()
self.budget_type = pipeline_config['budget_type'] if budget_type is None else budget_type
self.budget = pipeline_config[self.budget_type] if budget == 0 else budget
self.budget_type = pipeline_options['budget_type'] if budget_type is None else budget_type
self.budget = pipeline_options[self.budget_type] if budget == 0 else budget

self.num_run = 0 if num_run is None else num_run

@@ -539,7 +540,7 @@ def __init__(self, backend: Backend,
port=logger_port,
)

self._init_fit_dictionary(logger_port=logger_port, pipeline_config=pipeline_config, metrics_dict=metrics_dict)
self._init_fit_dictionary(logger_port=logger_port, pipeline_options=pipeline_options, metrics_dict=metrics_dict)
self.Y_optimization: Optional[np.ndarray] = None
self.Y_actual_train: Optional[np.ndarray] = None
self.pipelines: Optional[List[BaseEstimator]] = None
@@ -597,7 +598,7 @@ def _init_datamanager_info(
def _init_fit_dictionary(
self,
logger_port: int,
pipeline_config: Dict[str, Any],
pipeline_options: Dict[str, Any],
metrics_dict: Optional[Dict[str, List[str]]] = None,
) -> None:
"""
@@ -608,7 +609,7 @@ def _init_fit_dictionary(
Logging is performed using a socket-server scheme to be robust against many
parallel entities that want to write to the same file. This integer states the
socket port for the communication channel.
pipeline_config (Dict[str, Any]):
pipeline_options (Dict[str, Any]):
Defines the content of the pipeline being evaluated. For example, it
contains pipeline specific settings like logging name, or whether or not
to use tensorboard.
@@ -634,7 +635,7 @@ def _init_fit_dictionary(
'optimize_metric': self.metric.name
})

self.fit_dictionary.update(pipeline_config)
self.fit_dictionary.update(pipeline_options)
# If the budget is epochs, we want to limit that in the fit dictionary
if self.budget_type == 'epochs':
self.fit_dictionary['epochs'] = self.budget
@@ -805,6 +806,11 @@ def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float],
        if test_loss is not None:
            additional_run_info['test_loss'] = test_loss

        # Add information to additional info that can be useful for other functionalities
        additional_run_info['configuration'] = self.configuration \
            if not isinstance(self.configuration, Configuration) else self.configuration.get_dictionary()
        additional_run_info['budget'] = self.budget

        rval_dict = {'loss': cost,
                     'additional_run_info': additional_run_info,
                     'status': status}
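
finish_up() now records the evaluated configuration and budget in additional_run_info, so other functionality (such as a later refit) can recover them straight from the run results. A hedged sketch of just that serialisation step; the helper name is made up and a duck-typed check stands in for the isinstance test against ConfigSpace's Configuration:

from typing import Any, Dict

def serialise_run_info(configuration: Any, budget: float) -> Dict[str, Any]:
    # Traditional/dummy configurations are plain ints or strings and are stored as-is;
    # ConfigSpace Configuration objects are converted to a plain dictionary first.
    additional_run_info: Dict[str, Any] = {
        'configuration': configuration.get_dictionary() if hasattr(configuration, 'get_dictionary') else configuration,
        'budget': budget,
    }
    return additional_run_info
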
18 changes: 9 additions & 9 deletions autoPyTorch/evaluation/tae.py
@@ -123,7 +123,7 @@ def __init__(
abort_on_first_run_crash: bool,
pynisher_context: str,
multi_objectives: List[str],
pipeline_config: Optional[Dict[str, Any]] = None,
pipeline_options: Optional[Dict[str, Any]] = None,
initial_num_run: int = 1,
stats: Optional[Stats] = None,
run_obj: str = 'quality',
@@ -198,13 +198,13 @@ def __init__(
self.disable_file_output = disable_file_output
self.init_params = init_params

self.budget_type = pipeline_config['budget_type'] if pipeline_config is not None else budget_type
self.budget_type = pipeline_options['budget_type'] if pipeline_options is not None else budget_type

self.pipeline_config: Dict[str, Union[int, str, float]] = dict()
if pipeline_config is None:
pipeline_config = replace_string_bool_to_bool(json.load(open(
self.pipeline_options: Dict[str, Union[int, str, float]] = dict()
if pipeline_options is None:
pipeline_options = replace_string_bool_to_bool(json.load(open(
os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json'))))
self.pipeline_config.update(pipeline_config)
self.pipeline_options.update(pipeline_options)

self.logger_port = logger_port
if self.logger_port is None:
@@ -225,7 +225,7 @@ def __init__(
def _check_and_get_default_budget(self) -> float:
budget_type_choices_tabular = ('epochs', 'runtime')
budget_choices = {
budget_type: float(self.pipeline_config.get(budget_type, np.inf))
budget_type: float(self.pipeline_options.get(budget_type, np.inf))
for budget_type in budget_type_choices_tabular
}

@@ -234,7 +234,7 @@ def _check_and_get_default_budget(self) -> float:
budget_type_choices = budget_type_choices_tabular + FORECASTING_BUDGET_TYPE

# budget is defined by epochs by default
budget_type = str(self.pipeline_config.get('budget_type', 'epochs'))
budget_type = str(self.pipeline_options.get('budget_type', 'epochs'))
if self.budget_type is not None:
budget_type = self.budget_type

@@ -361,7 +361,7 @@ def run(
init_params=init_params,
budget=budget,
budget_type=self.budget_type,
pipeline_config=self.pipeline_config,
pipeline_options=self.pipeline_options,
logger_port=self.logger_port,
all_supported_metrics=self.all_supported_metrics,
search_space_updates=self.search_space_updates
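
The pipeline_config to pipeline_options rename runs through the whole evaluation stack; the behavioural detail in the hunk above is the fallback that loads the packaged defaults when no options are passed. A rough sketch of that fallback, with an illustrative helper name, and leaving out the "True"/"False"-to-bool conversion that replace_string_bool_to_bool performs in the real code:

import json
from typing import Any, Dict, Optional

def resolve_pipeline_options(pipeline_options: Optional[Dict[str, Any]],
                             defaults_path: str) -> Dict[str, Any]:
    # When no options are supplied, fall back to the JSON defaults shipped with
    # the package, then copy them into a fresh dict owned by the caller.
    resolved: Dict[str, Any] = {}
    if pipeline_options is None:
        with open(defaults_path) as fh:
            pipeline_options = json.load(fh)
    resolved.update(pipeline_options)
    return resolved

# e.g. resolve_pipeline_options(None, 'autoPyTorch/configs/default_pipeline_options.json')
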
10 changes: 5 additions & 5 deletions autoPyTorch/evaluation/test_evaluator.py
@@ -51,7 +51,7 @@ class TestEvaluator(AbstractEvaluator):
The amount of epochs/time a configuration is allowed to run.
budget_type (str):
The budget type, which can be epochs or time
pipeline_config (Optional[Dict[str, Any]]):
pipeline_options (Optional[Dict[str, Any]]):
Defines the content of the pipeline being evaluated. For example, it
contains pipeline specific settings like logging name, or whether or not
to use tensorboard.
@@ -113,7 +113,7 @@ def __init__(
budget: float,
configuration: Union[int, str, Configuration],
budget_type: str = None,
pipeline_config: Optional[Dict[str, Any]] = None,
pipeline_options: Optional[Dict[str, Any]] = None,
seed: int = 1,
output_y_hat_optimization: bool = False,
num_run: Optional[int] = None,
@@ -141,7 +141,7 @@ def __init__(
budget_type=budget_type,
logger_port=logger_port,
all_supported_metrics=all_supported_metrics,
pipeline_config=pipeline_config,
pipeline_options=pipeline_options,
search_space_updates=search_space_updates
)

@@ -206,7 +206,7 @@ def eval_test_function(
include: Optional[Dict[str, Any]],
exclude: Optional[Dict[str, Any]],
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
pipeline_config: Optional[Dict[str, Any]] = None,
pipeline_options: Optional[Dict[str, Any]] = None,
budget_type: str = None,
init_params: Optional[Dict[str, Any]] = None,
logger_port: Optional[int] = None,
@@ -230,7 +230,7 @@ def eval_test_function(
budget_type=budget_type,
logger_port=logger_port,
all_supported_metrics=all_supported_metrics,
pipeline_config=pipeline_config,
pipeline_options=pipeline_options,
search_space_updates=search_space_updates)

evaluator.fit_predict_and_loss()
@@ -40,7 +40,7 @@ class TimeSeriesForecastingTrainEvaluator(TrainEvaluator):
The amount of epochs/time a configuration is allowed to run.
budget_type (str):
The budget type, which can be epochs or time
pipeline_config (Optional[Dict[str, Any]]):
pipeline_options (Optional[Dict[str, Any]]):
Defines the content of the pipeline being evaluated. For example, it
contains pipeline specific settings like logging name, or whether or not
to use tensorboard.
@@ -106,7 +106,7 @@ def __init__(self, backend: Backend, queue: Queue,
metric: autoPyTorchMetric,
budget: float,
budget_type: str = None,
pipeline_config: Optional[Dict[str, Any]] = None,
pipeline_options: Optional[Dict[str, Any]] = None,
configuration: Optional[Configuration] = None,
seed: int = 1,
output_y_hat_optimization: bool = True,
@@ -138,7 +138,7 @@ def __init__(self, backend: Backend, queue: Queue,
logger_port=logger_port,
keep_models=keep_models,
all_supported_metrics=all_supported_metrics,
pipeline_config=pipeline_config,
pipeline_options=pipeline_options,
search_space_updates=search_space_updates
)
self.datamanager = backend.load_datamanager()
@@ -456,7 +456,7 @@ def forecasting_eval_train_function(
include: Optional[Dict[str, Any]],
exclude: Optional[Dict[str, Any]],
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
pipeline_config: Optional[Dict[str, Any]] = None,
pipeline_options: Optional[Dict[str, Any]] = None,
budget_type: str = None,
init_params: Optional[Dict[str, Any]] = None,
logger_port: Optional[int] = None,
@@ -490,7 +490,7 @@ def forecasting_eval_train_function(
The amount of epochs/time a configuration is allowed to run.
budget_type (str):
The budget type, which can be epochs or time
pipeline_config (Optional[Dict[str, Any]]):
pipeline_options (Optional[Dict[str, Any]]):
Defines the content of the pipeline being evaluated. For example, it
contains pipeline specific settings like logging name, or whether or not
to use tensorboard.
@@ -550,7 +550,7 @@ def forecasting_eval_train_function(
budget_type=budget_type,
logger_port=logger_port,
all_supported_metrics=all_supported_metrics,
pipeline_config=pipeline_config,
pipeline_options=pipeline_options,
search_space_updates=search_space_updates,
max_budget=max_budget,
min_num_test_instances=min_num_test_instances,
12 changes: 6 additions & 6 deletions autoPyTorch/evaluation/train_evaluator.py
@@ -60,7 +60,7 @@ class TrainEvaluator(AbstractEvaluator):
The amount of epochs/time a configuration is allowed to run.
budget_type (str):
The budget type, which can be epochs or time
pipeline_config (Optional[Dict[str, Any]]):
pipeline_options (Optional[Dict[str, Any]]):
Defines the content of the pipeline being evaluated. For example, it
contains pipeline specific settings like logging name, or whether or not
to use tensorboard.
@@ -121,7 +121,7 @@ def __init__(self, backend: Backend, queue: Queue,
budget: float,
configuration: Union[int, str, Configuration],
budget_type: str = None,
pipeline_config: Optional[Dict[str, Any]] = None,
pipeline_options: Optional[Dict[str, Any]] = None,
seed: int = 1,
output_y_hat_optimization: bool = True,
num_run: Optional[int] = None,
@@ -149,7 +149,7 @@ def __init__(self, backend: Backend, queue: Queue,
budget_type=budget_type,
logger_port=logger_port,
all_supported_metrics=all_supported_metrics,
pipeline_config=pipeline_config,
pipeline_options=pipeline_options,
search_space_updates=search_space_updates
)

@@ -420,7 +420,7 @@ def eval_train_function(
include: Optional[Dict[str, Any]],
exclude: Optional[Dict[str, Any]],
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
pipeline_config: Optional[Dict[str, Any]] = None,
pipeline_options: Optional[Dict[str, Any]] = None,
budget_type: str = None,
init_params: Optional[Dict[str, Any]] = None,
logger_port: Optional[int] = None,
@@ -452,7 +452,7 @@ def eval_train_function(
The amount of epochs/time a configuration is allowed to run.
budget_type (str):
The budget type, which can be epochs or time
pipeline_config (Optional[Dict[str, Any]]):
pipeline_options (Optional[Dict[str, Any]]):
Defines the content of the pipeline being evaluated. For example, it
contains pipeline specific settings like logging name, or whether or not
to use tensorboard.
@@ -506,7 +506,7 @@ def eval_train_function(
budget_type=budget_type,
logger_port=logger_port,
all_supported_metrics=all_supported_metrics,
pipeline_config=pipeline_config,
pipeline_options=pipeline_options,
search_space_updates=search_space_updates,
)
evaluator.fit_predict_and_loss()
8 changes: 4 additions & 4 deletions autoPyTorch/optimizer/smbo.py
@@ -111,7 +111,7 @@ def __init__(self,
watcher: StopWatch,
n_jobs: int,
dask_client: Optional[dask.distributed.Client],
pipeline_config: Dict[str, Any],
pipeline_options: Dict[str, Any],
start_num_run: int = 1,
seed: int = 1,
resampling_strategy: Union[HoldoutValTypes,
@@ -227,7 +227,7 @@ def __init__(self,
self.backend = backend
self.all_supported_metrics = all_supported_metrics

self.pipeline_config = pipeline_config
self.pipeline_options = pipeline_options
# the configuration space
self.config_space = config_space

@@ -326,7 +326,7 @@ def run_smbo(self, func: Optional[Callable] = None
ta=func,
logger_port=self.logger_port,
all_supported_metrics=self.all_supported_metrics,
pipeline_config=self.pipeline_config,
pipeline_options=self.pipeline_options,
search_space_updates=self.search_space_updates,
pynisher_context=self.pynisher_context,
)
@@ -376,7 +376,7 @@ def run_smbo(self, func: Optional[Callable] = None
)
scenario_dict.update(self.smac_scenario_args)

budget_type = self.pipeline_config['budget_type']
budget_type = self.pipeline_options['budget_type']
if budget_type in FORECASTING_BUDGET_TYPE:
if STRING_TO_TASK_TYPES.get(self.task_type, -1) != TIMESERIES_FORECASTING:
raise ValueError('Forecasting Budget type is only available for forecasting task!')
@@ -52,8 +52,7 @@ def __init__(
self.add_fit_requirements([
FitRequirement('X_train', (np.ndarray, list, pd.DataFrame), user_defined=False, dataset_property=False),
FitRequirement('y_train', (np.ndarray, list, pd.Series,), user_defined=False, dataset_property=False),
FitRequirement('train_indices', (np.ndarray, list), user_defined=False, dataset_property=False),
FitRequirement('val_indices', (np.ndarray, list), user_defined=False, dataset_property=False)])
FitRequirement('train_indices', (np.ndarray, list), user_defined=False, dataset_property=False)])

def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent:
"""
@@ -90,8 +89,14 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent:

# train model
blockPrint()
val_indices = X.get('val_indices', None)
X_val = None
y_val = None
if val_indices is not None:
X_val = X['X_train'][val_indices]
y_val = X['y_train'][val_indices]
self.fit_output = self.model.fit(X['X_train'][X['train_indices']], X['y_train'][X['train_indices']],
X['X_train'][X['val_indices']], X['y_train'][X['val_indices']])
X_val, y_val)
enablePrint()

# infer
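
The hunk above drops val_indices from the trainer's fit requirements and only slices out validation data when the key is present, presumably so the traditional models can be refitted on the full training set without a hold-out split. A small, self-contained sketch of that slicing pattern with NumPy arrays; the helper name and the example fit dictionary are illustrative:

from typing import Any, Dict, Optional, Tuple

import numpy as np

def split_train_val(X: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray,
                                                 Optional[np.ndarray], Optional[np.ndarray]]:
    # Training rows are always selected; validation rows only when indices exist.
    X_train = X['X_train'][X['train_indices']]
    y_train = X['y_train'][X['train_indices']]
    val_indices = X.get('val_indices', None)
    X_val = X['X_train'][val_indices] if val_indices is not None else None
    y_val = X['y_train'][val_indices] if val_indices is not None else None
    return X_train, y_train, X_val, y_val

# Hypothetical fit dictionary without a validation split, as during refit.
X = {'X_train': np.arange(20).reshape(10, 2), 'y_train': np.arange(10),
     'train_indices': np.arange(10)}
X_tr, y_tr, X_va, y_va = split_train_val(X)  # X_va and y_va are None
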
