
move to a minimization problem #113


Merged
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
@@ -29,7 +29,7 @@ jobs:
- name: Run tests
run: |
if [ ${{ matrix.code-cov }} ]; then codecov='--cov=autoPyTorch --cov-report=xml'; fi
python -m pytest --durations=20 --timeout=300 --timeout-method=thread -v $codecov test
python -m pytest --durations=20 --timeout=600 --timeout-method=signal -v $codecov test
- name: Check for files left behind by test
if: ${{ always() }}
run: |
208 changes: 104 additions & 104 deletions autoPyTorch/ensemble/ensemble_builder.py

Large diffs are not rendered by default.

132 changes: 99 additions & 33 deletions autoPyTorch/ensemble/ensemble_selection.py
Expand Up @@ -6,7 +6,7 @@
from autoPyTorch.ensemble.abstract_ensemble import AbstractEnsemble
from autoPyTorch.pipeline.base_pipeline import BasePipeline
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score
from autoPyTorch.pipeline.components.training.metrics.utils import calculate_loss


class EnsembleSelection(AbstractEnsemble):
@@ -39,6 +39,24 @@ def fit(
labels: np.ndarray,
identifiers: List[Tuple[int, int, float]],
) -> AbstractEnsemble:
"""
Builds an ensemble given the individual models' out-of-fold predictions.
Fundamentally, it defines a set of weights for a soft-voting
aggregation of the models referenced by the given identifiers.

Args:
predictions (List[np.array]):
A list of individual model predictions of shape (n_datapoints, n_targets)
corresponding to the OutOfFold estimate of the ground truth
labels (np.ndarray):
The ground truth targets of shape (n_datapoints, n_targets)
identifiers (List[Tuple[int, int, float]]):
A list of model identifiers, each with the form
(seed, run number, budget)

Returns:
A copy of self
"""
self.ensemble_size = int(self.ensemble_size)
if self.ensemble_size < 1:
raise ValueError('Ensemble size cannot be less than one!')
@@ -53,7 +71,20 @@ def _fit(
predictions: List[np.ndarray],
labels: np.ndarray,
) -> None:
"""Fast version of Rich Caruana's ensemble selection method."""
"""
Fast version of Rich Caruana's ensemble selection method.

For more details, please check the paper
"Ensemble Selection from Library of Models" by R Caruana (2004)

Args:
predictions (List[np.array]):
A list of individual model predictions of shape (n_datapoints, n_targets)
corresponding to the OutOfFold estimate of the ground truth
labels (np.ndarray):
The ground truth targets of shape (n_datapoints, n_targets)
"""
self.num_input_models_ = len(predictions)

ensemble = [] # type: List[np.ndarray]
@@ -71,60 +102,47 @@ def _fit(
dtype=np.float64,
)
for i in range(ensemble_size):
scores = np.zeros(
losses = np.zeros(
(len(predictions)),
dtype=np.float64,
)
s = len(ensemble)
if s == 0:
weighted_ensemble_prediction.fill(0.0)
else:
weighted_ensemble_prediction.fill(0.0)
for pred in ensemble:
np.add(
weighted_ensemble_prediction,
pred,
out=weighted_ensemble_prediction,
)
np.multiply(
weighted_ensemble_prediction,
1 / s,
out=weighted_ensemble_prediction,
)
np.multiply(
if s > 0:
np.add(
weighted_ensemble_prediction,
(s / float(s + 1)),
ensemble[-1],
out=weighted_ensemble_prediction,
)

# Memory-efficient averaging!
for j, pred in enumerate(predictions):
# Memory-efficient averaging!
fant_ensemble_prediction.fill(0.0)
# fant_ensemble_prediction is the prediction of the current ensemble
# and should be (predictions[selected_prev_iterations] + predictions[j]) / (s+1)
# We overwrite the contents of fant_ensemble_prediction
# directly with weighted_ensemble_prediction + new_prediction and then scale for avg
np.add(
fant_ensemble_prediction,
weighted_ensemble_prediction,
pred,
out=fant_ensemble_prediction
)
np.add(
np.multiply(
fant_ensemble_prediction,
(1. / float(s + 1)) * pred,
(1. / float(s + 1)),
out=fant_ensemble_prediction
)

# Calculate score is versatile and can return a dict of score
# when all_scoring_functions=False, we know it will be a float
score = calculate_score(
# calculate_loss is versatile and can return a dict of losses
losses[j] = calculate_loss(
metrics=[self.metric],
target=labels,
prediction=fant_ensemble_prediction,
task_type=self.task_type,
)
scores[j] = self.metric._optimum - score[self.metric.name]
)[self.metric.name]

all_best = np.argwhere(scores == np.nanmin(scores)).flatten()
all_best = np.argwhere(losses == np.nanmin(losses)).flatten()
best = self.random_state.choice(all_best)
ensemble.append(predictions[best])
trajectory.append(scores[best])
trajectory.append(losses[best])
order.append(best)

# Handle special case
@@ -133,9 +151,15 @@ def _fit(

self.indices_ = order
self.trajectory_ = trajectory
self.train_score_ = trajectory[-1]
self.train_loss_ = trajectory[-1]

def _calculate_weights(self) -> None:
"""
Calculates the contribution each individual model should have
in the final ensemble soft voting. It does so via a frequency-counting
scheme: the more often a model was selected during the hill-climbing
optimization, the larger its weight.
"""
ensemble_members = Counter(self.indices_).most_common()
weights = np.zeros(
(self.num_input_models_,),
@@ -151,6 +175,19 @@ def _calculate_weights(self) -> None:
self.weights_ = weights

def predict(self, predictions: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
"""
Given a list of predictions from the individual models, this method
aggregates the predictions using a soft voting scheme with the weights
found during training.

Args:
predictions (List[np.ndarray]):
A list of predictions from the individual base models.

Returns:
average (np.array): Soft voting predictions of the ensemble models, using
the weights found during ensemble selection (self.weights_)
"""

average = np.zeros_like(predictions[0], dtype=np.float64)
tmp_predictions = np.empty_like(predictions[0], dtype=np.float64)
@@ -191,6 +228,19 @@ def get_models_with_weights(
self,
models: Dict[Any, BasePipeline]
) -> List[Tuple[float, BasePipeline]]:
"""
Handy function to tag the provided input models with a given weight.

Args:
models (Dict[Any, BasePipeline]):
A dictionary that maps a model's identifier to its actual python object.

Returns:
output (List[Tuple[float, BasePipeline]]):
Each model with its associated weight, sorted by ascending
performance. Notice that ensemble selection solves a minimization
problem.
"""
output = []
for i, weight in enumerate(self.weights_):
if weight > 0.0:
@@ -203,6 +253,15 @@ def get_selected_model_identifiers(self) -> List[Tuple[int, int, float]]:
return output

def get_selected_model_identifiers(self) -> List[Tuple[int, int, float]]:
"""
After ensemble selection is trained, not all models will be used.
Some of them will have zero weight. This procedure filters these
models out.

Returns:
output (List[Tuple[int, int, float]]):
The models actually used by ensemble selection
"""
output = []

for i, weight in enumerate(self.weights_):
@@ -213,4 +272,11 @@ def get_selected_model_identifiers(self) -> List[Tuple[int, int, float]]:
return output

def get_validation_performance(self) -> float:
"""
Returns the best optimization performance seen during hill climbing

Returns:
(float):
best ensemble training performance
"""
return self.trajectory_[-1]
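For context on the selection loop above, here is a minimal standalone sketch of Caruana-style greedy ensemble selection cast as a minimization problem; the names toy_loss and greedy_selection and the mean-squared-error loss are illustrative assumptions, not part of the autoPyTorch API.

from collections import Counter

import numpy as np


def toy_loss(target: np.ndarray, prediction: np.ndarray) -> float:
    # Mean squared error is already a loss, so lower is better.
    return float(np.mean((target - prediction) ** 2))


def greedy_selection(predictions, target, ensemble_size=5, seed=1):
    rng = np.random.RandomState(seed)
    order = []
    running_sum = np.zeros_like(predictions[0], dtype=np.float64)
    for s in range(ensemble_size):
        losses = np.zeros(len(predictions), dtype=np.float64)
        for j, pred in enumerate(predictions):
            # "Fantasy" ensemble: average of the models picked so far plus candidate j.
            fantasy = (running_sum + pred) / float(s + 1)
            losses[j] = toy_loss(target, fantasy)
        # Break ties uniformly at random among the minimizers.
        best = int(rng.choice(np.argwhere(losses == np.nanmin(losses)).flatten()))
        order.append(best)
        running_sum += predictions[best]
    # Frequency counting: a model's weight is how often hill climbing picked it.
    counts = Counter(order)
    return np.array([counts.get(i, 0) / float(ensemble_size)
                     for i in range(len(predictions))])


# Usage: three dummy "models" predicting a small regression target.
target = np.array([0.0, 1.0, 2.0, 3.0])
preds = [target + 0.5, target - 0.4, np.zeros_like(target)]
weights = greedy_selection(preds, target)
soft_vote = sum(w * p for w, p in zip(weights, preds))  # weighted soft-voting prediction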
21 changes: 6 additions & 15 deletions autoPyTorch/evaluation/abstract_evaluator.py
@@ -38,7 +38,7 @@
from autoPyTorch.pipeline.base_pipeline import BasePipeline
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.pipeline.components.training.metrics.utils import (
calculate_score,
calculate_loss,
get_metrics,
)
from autoPyTorch.utils.backend import Backend
@@ -364,30 +364,21 @@ def _get_pipeline(self) -> BaseEstimator:
def _loss(self, y_true: np.ndarray, y_hat: np.ndarray) -> Dict[str, float]:
"""SMAC follows a minimization goal, so the make_scorer
sign is used as a guide to obtain the value to reduce.
The calculate_loss function internally translates a score function into
a minimization problem

On this regard, to optimize a metric:
1- score is calculared with calculate_score, with the caveat, that if
for the metric greater is not better, a negative score is returned.
2- the err (the optimization goal) is then:
optimum - (metric.sign * actual_score)
For accuracy for example: optimum(1) - (+1 * actual score)
For logloss for example: optimum(0) - (-1 * actual score)
"""

if not isinstance(self.configuration, Configuration):
return {self.metric.name: 1.0}
return {self.metric.name: self.metric._worst_possible_result}

if self.additional_metrics is not None:
metrics = self.additional_metrics
else:
metrics = [self.metric]
score = calculate_score(
y_true, y_hat, self.task_type, metrics)

err = {metric.name: metric._optimum - score[metric.name] for metric in metrics
if metric.name in score.keys()}

return err
return calculate_loss(
y_true, y_hat, self.task_type, metrics)

def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float],
opt_pred: np.ndarray, valid_pred: Optional[np.ndarray],
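To make the translation concrete, a tiny worked example following the optimum - sign * score convention described in the docstring above; the helper name to_loss and the numeric values are illustrative only, and the score is assumed to be the raw (unsigned) metric value.

# Illustrative only: loss = optimum - sign * score, with score as the raw metric value.
def to_loss(optimum: float, sign: int, score: float) -> float:
    return optimum - sign * score


print(to_loss(optimum=1.0, sign=+1, score=0.90))  # accuracy 0.90 -> loss ~0.10
print(to_loss(optimum=0.0, sign=-1, score=0.30))  # log loss 0.30 -> loss 0.30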
58 changes: 52 additions & 6 deletions autoPyTorch/pipeline/components/training/metrics/utils.py
@@ -104,17 +104,17 @@ def get_metrics(dataset_properties: Dict[str, Any],


def calculate_score(
target: np.ndarray,
prediction: np.ndarray,
task_type: int,
metrics: Iterable[autoPyTorchMetric],
target: np.ndarray,
prediction: np.ndarray,
task_type: int,
metrics: Iterable[autoPyTorchMetric],
) -> Dict[str, float]:
score_dict = dict()
if task_type in REGRESSION_TASKS:
cprediction = sanitize_array(prediction)
for metric_ in metrics:
try:
score_dict[metric_.name] = metric_(target, cprediction)
score_dict[metric_.name] = metric_._sign * metric_(target, cprediction)
except ValueError as e:
warnings.warn(f"{e} {e.args[0]}")
if e.args[0] == "Mean Squared Logarithmic Error cannot be used when " \
@@ -126,7 +126,7 @@ def calculate_score(
else:
for metric_ in metrics:
try:
score_dict[metric_.name] = metric_(target, prediction)
score_dict[metric_.name] = metric_._sign * metric_(target, prediction)
except ValueError as e:
if e.args[0] == 'multiclass format is not supported':
continue
@@ -143,3 +143,49 @@
else:
raise e
return score_dict


def calculate_loss(
target: np.ndarray,
prediction: np.ndarray,
task_type: int,
metrics: Iterable[autoPyTorchMetric],
) -> Dict[str, float]:
"""
Returns a loss (a magnitude that allows casting the
optimization problem as a minimization one) for the
given autoPyTorchMetric objects

Parameters
----------
target: np.ndarray
The ground truth of the targets
prediction: np.ndarray
The best estimate from the model, of the given targets
task_type: int
To understand if the problem task is classification
or regression
metrics: Iterable[autoPyTorchMetric]
Objects that host a function to calculate how good the
prediction is according to the target

Returns
-------
Dict[str, float]
A loss for each of the provided metrics
"""
score = calculate_score(
target=target,
prediction=prediction,
task_type=task_type,
metrics=metrics,
)

loss_dict = dict()
for metric_ in metrics:
# TODO: When metrics are annotated with type_of_target support
# we can remove this check
if metric_.name not in score:
continue
loss_dict[metric_.name] = metric_._optimum - metric_._sign * score[metric_.name]
return loss_dict
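For a rough sense of how such a conversion behaves across several metrics, here is a small mock-up; FakeMetric and toy_calculate_loss are illustrative stand-ins with only the attributes used below, not the real autoPyTorchMetric API.

from dataclasses import dataclass
from typing import Callable, Dict, Iterable

import numpy as np


@dataclass
class FakeMetric:
    # Stand-in for autoPyTorchMetric: only the attributes used below.
    name: str
    _optimum: float
    _sign: int
    fn: Callable[[np.ndarray, np.ndarray], float]


def toy_calculate_loss(target: np.ndarray, prediction: np.ndarray,
                       metrics: Iterable[FakeMetric]) -> Dict[str, float]:
    losses = {}
    for m in metrics:
        raw = m.fn(target, prediction)               # raw (unsigned) metric value
        losses[m.name] = m._optimum - m._sign * raw  # cast to a minimization target
    return losses


target = np.array([0, 1, 1, 0])
prediction = np.array([0, 1, 0, 0])
accuracy = FakeMetric("accuracy", _optimum=1.0, _sign=+1,
                      fn=lambda y, p: float(np.mean(y == p)))
mae = FakeMetric("mean_absolute_error", _optimum=0.0, _sign=-1,
                 fn=lambda y, p: float(np.mean(np.abs(y - p))))
print(toy_calculate_loss(target, prediction, [accuracy, mae]))
# {'accuracy': 0.25, 'mean_absolute_error': 0.25}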
4 changes: 2 additions & 2 deletions test/test_api/test_api.py
@@ -83,7 +83,7 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):
'.autoPyTorch/ensemble_read_preds.pkl',
'.autoPyTorch/start_time_1',
'.autoPyTorch/ensemble_history.json',
'.autoPyTorch/ensemble_read_scores.pkl',
'.autoPyTorch/ensemble_read_losses.pkl',
'.autoPyTorch/true_targets_ensemble.npy',
]
for expected_file in expected_files:
@@ -244,7 +244,7 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
'.autoPyTorch/ensemble_read_preds.pkl',
'.autoPyTorch/start_time_1',
'.autoPyTorch/ensemble_history.json',
'.autoPyTorch/ensemble_read_scores.pkl',
'.autoPyTorch/ensemble_read_losses.pkl',
'.autoPyTorch/true_targets_ensemble.npy',
]
for expected_file in expected_files: