Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure the performance of pipeline is at least 0.8 #82

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def set_pipeline_config(
then sets them to the current pipeline
configuration.
Args:
**pipeline_config_kwargs: Valid config options include "job_id",
**pipeline_config_kwargs: Valid config options include "num_run",
"device", "budget_type", "epochs", "runtime", "torch_num_threads",
"early_stopping", "use_tensorboard_logger", "use_pynisher",
"metrics_during_training"
Expand Down Expand Up @@ -923,7 +923,7 @@ def refit(
'train_indices': dataset.splits[split_id][0],
'val_indices': dataset.splits[split_id][1],
'split_id': split_id,
'job_id': 0
'num_run': 0
})
X.update({**self.pipeline_options, **budget_config})
if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None:
Expand Down Expand Up @@ -996,7 +996,7 @@ def fit(self,
'train_indices': dataset.splits[split_id][0],
'val_indices': dataset.splits[split_id][1],
'split_id': split_id,
'job_id': 0
'num_run': 0
})
X.update({**self.pipeline_options, **budget_config})

Expand Down
2 changes: 1 addition & 1 deletion autoPyTorch/evaluation/train_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _fit_and_predict(self, pipeline: BaseEstimator, fold: int, train_indices: Un
X = {'train_indices': train_indices,
'val_indices': test_indices,
'split_id': fold,
'job_id': self.num_run,
'num_run': self.num_run,
**self.fit_dictionary} # fit dictionary
y = None
fit_and_suppress_warnings(self.logger, pipeline, X, y)
Expand Down
2 changes: 1 addition & 1 deletion autoPyTorch/pipeline/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def get_additional_run_info(self) -> Dict:
@staticmethod
def get_default_pipeline_options() -> Dict[str, Any]:
return {
'job_id': '1',
'num_run': 0,
'device': 'cpu',
'budget_type': 'epochs',
'epochs': 5,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,14 @@ def get_hyperparameter_search_space(
raise ValueError("No scheduler found")

if default is None:
defaults = ['no_LRScheduler',
'LambdaLR',
'StepLR',
'ExponentialLR',
'CosineAnnealingLR',
'ReduceLROnPlateau'
]
defaults = [
'ReduceLROnPlateau',
'CosineAnnealingLR',
'no_LRScheduler',
'LambdaLR',
'StepLR',
'ExponentialLR',
]
for default_ in defaults:
if default_ in available_schedulers:
default = default_
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_best_epoch(self, loss_type: str = 'val_loss') -> int:
[self.performance_tracker[loss_type][e] for e in range(1, len(
self.performance_tracker[loss_type]) + 1
)]
)
) + 1 # Epochs start at 1

def get_last_epoch(self) -> int:
if 'train_loss' not in self.performance_tracker:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import logging.handlers
import os
import tempfile
import time
from typing import Any, Dict, List, Optional, Tuple, cast

Expand Down Expand Up @@ -66,7 +67,7 @@ def __init__(self,
self.writer = None # type: Optional[SummaryWriter]
self._fit_requirements: Optional[List[FitRequirement]] = [
FitRequirement("lr_scheduler", (_LRScheduler,), user_defined=False, dataset_property=False),
FitRequirement("job_id", (str,), user_defined=False, dataset_property=False),
FitRequirement("num_run", (int,), user_defined=False, dataset_property=False),
FitRequirement(
"optimizer", (Optimizer,), user_defined=False, dataset_property=False),
FitRequirement("train_data_loader",
Expand All @@ -75,6 +76,7 @@ def __init__(self,
FitRequirement("val_data_loader",
(torch.utils.data.DataLoader,),
user_defined=False, dataset_property=False)]
self.checkpoint_dir = None # type: Optional[str]

def get_fit_requirements(self) -> Optional[List[FitRequirement]]:
return self._fit_requirements
Expand Down Expand Up @@ -185,7 +187,7 @@ def fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> autoPyTorchCom

# Setup the logger
self.logger = get_named_client_logger(
name=X['job_id'],
name=X['num_run'],
# Log to a user provided port else to the default logging port
port=X['logger_port'
] if 'logger_port' in X else logging.handlers.DEFAULT_TCP_LOGGING_PORT,
Expand Down Expand Up @@ -369,8 +371,29 @@ def early_stop_handler(self, X: Dict[str, Any]) -> bool:
bool: If true, training should be stopped
"""
assert self.run_summary is not None
epochs_since_best = self.run_summary.get_best_epoch() - self.run_summary.get_last_epoch()

# Allow to disable early stopping
if X['early_stopping'] is None or X['early_stopping'] < 0:
return False

# Store the best weights seen so far:
if self.checkpoint_dir is None:
self.checkpoint_dir = tempfile.mkdtemp(dir=X['backend'].temporary_directory)

epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch()

# Save the checkpoint if there is a new best epoch
best_path = os.path.join(self.checkpoint_dir, 'best.pth')
if epochs_since_best == 0:
torch.save(X['network'].state_dict(), best_path)

if epochs_since_best > X['early_stopping']:
self.logger.debug(f" Early stopped model {X['num_run']} on epoch {self.run_summary.get_best_epoch()}")
# We will stop the training. Load the last best performing weights
X['network'].load_state_dict(torch.load(best_path))

# Let the tempfile module clean the temp dir
self.checkpoint_dir = None
return True

return False
Expand Down Expand Up @@ -458,8 +481,8 @@ def check_requirements(self, X: Dict[str, Any], y: Any = None) -> None:
X['budget_type']
))

if 'job_id' not in X:
raise ValueError('To fit a trainer, expected fit dictionary to have a job_id')
if 'num_run' not in X:
raise ValueError('To fit a trainer, expected fit dictionary to have a num_run')

for config_option in ["torch_num_threads", 'device']:
if config_option not in X:
Expand Down
7 changes: 7 additions & 0 deletions examples/example_tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
Tabular Classification
======================
"""
import os
import tempfile as tmp
import typing
import warnings

os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

Expand Down
6 changes: 3 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def fit_dictionary_numerical_only(backend):
'X_train': X,
'y_train': y,
'dataset_properties': dataset_properties,
'job_id': 'example_tabular_classification_1',
'num_run': np.random.randint(50),
'device': 'cpu',
'budget_type': 'epochs',
'epochs': 1,
Expand Down Expand Up @@ -220,7 +220,7 @@ def fit_dictionary_categorical_only(backend):
'X_train': X,
'y_train': y,
'dataset_properties': dataset_properties,
'job_id': 'example_tabular_classification_1',
'num_run': np.random.randint(50),
'device': 'cpu',
'budget_type': 'epochs',
'epochs': 1,
Expand Down Expand Up @@ -262,7 +262,7 @@ def fit_dictionary_num_and_categorical(backend):
'X_train': X,
'y_train': y,
'dataset_properties': dataset_properties,
'job_id': 'example_tabular_classification_1',
'num_run': np.random.randint(50),
'device': 'cpu',
'budget_type': 'epochs',
'epochs': 1,
Expand Down
33 changes: 29 additions & 4 deletions test/test_pipeline/components/test_setup_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def test_pipeline_fit(self, fit_dictionary, backbone, head):
assert backbone == config.get('network_backbone:__choice__', None)
assert head == config.get('network_head:__choice__', None)
pipeline.set_hyperparameters(config)

# Need more epochs to make sure validation performance is met
fit_dictionary['epochs'] = 100
# Early stop to the best configuration seen
fit_dictionary['early_stopping'] = 50

pipeline.fit(fit_dictionary)

# To make sure we fitted the model, there should be a
Expand All @@ -44,9 +50,28 @@ def test_pipeline_fit(self, fit_dictionary, backbone, head):
assert run_summary.total_parameter_count > 0
assert 'accuracy' in run_summary.performance_tracker['train_metrics'][1]

# Commented out the next line as some pipelines are not
# achieving this accuracy with default configuration and 10 epochs
# To be added once we fix the search space
# assert run_summary.performance_tracker['val_metrics'][fit_dictionary['epochs']]['accuracy'] >= 0.8
# Make sure default pipeline achieves a good score for dummy datasets
epoch2loss = run_summary.performance_tracker['val_loss']
best_loss = min(list(epoch2loss.values()))
epoch_where_best = list(epoch2loss.keys())[list(epoch2loss.values()).index(best_loss)]
score = run_summary.performance_tracker['val_metrics'][epoch_where_best]['accuracy']

assert score >= 0.8, run_summary.performance_tracker['val_metrics']

# Check that early stopping happened, if it did

# We should not stop before patience
assert run_summary.get_last_epoch() >= fit_dictionary['early_stopping']

# we should not be greater than max allowed epoch
assert run_summary.get_last_epoch() <= fit_dictionary['epochs']

# every trained epoch has a val metric
assert run_summary.get_last_epoch() == max(list(run_summary.performance_tracker['train_metrics'].keys()))

epochs_since_best = run_summary.get_last_epoch() - run_summary.get_best_epoch()
if epochs_since_best >= fit_dictionary['early_stopping']:
assert run_summary.get_best_epoch() == epoch_where_best

# Make sure a network was fit
assert isinstance(pipeline.named_steps['network'].get_network(), torch.nn.Module)
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_tabular_preprocess(self):
val_indices=[6, 7, 8, 9],
dataset_properties=dataset_properties,
# Training configuration
job_id='test',
num_run=15,
device='cpu',
budget_type='epochs',
epochs=10,
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_tabular_no_preprocess(self):
val_indices=[6, 7, 8, 9],
dataset_properties=dataset_properties,
# Training configuration
job_id='test',
num_run=16,
device='cpu',
budget_type='epochs',
epochs=10,
Expand Down
2 changes: 1 addition & 1 deletion test/test_pipeline/test_tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_remove_key_check_requirements(self, fit_dictionary):
"""Makes sure that when a key is removed from X, correct error is outputted"""
pipeline = TabularClassificationPipeline(
dataset_properties=fit_dictionary['dataset_properties'])
for key in ['job_id', 'device', 'split_id', 'use_pynisher', 'torch_num_threads',
for key in ['num_run', 'device', 'split_id', 'use_pynisher', 'torch_num_threads',
'dataset_properties', ]:
fit_dictionary_copy = fit_dictionary.copy()
fit_dictionary_copy.pop(key)
Expand Down