Skip to content

Commit 4a0c773

Browse files
[FIX] Datamanager in memory (#382)
* remove datamanager instances from evaluation and smbo
* fix flake
* Apply suggestions from code review

Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>

* fix flake

Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
1 parent b5c1757 commit 4a0c773

File tree

4 files changed

+58
-57
lines changed

4 files changed

+58
-57
lines changed

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -433,34 +433,16 @@ def __init__(self, backend: Backend,
433433
self.backend: Backend = backend
434434
self.queue = queue
435435

436-
self.datamanager: BaseDataset = self.backend.load_datamanager()
437-
438-
assert self.datamanager.task_type is not None, \
439-
"Expected dataset {} to have task_type got None".format(self.datamanager.__class__.__name__)
440-
self.task_type = STRING_TO_TASK_TYPES[self.datamanager.task_type]
441-
self.output_type = STRING_TO_OUTPUT_TYPES[self.datamanager.output_type]
442-
self.issparse = self.datamanager.issparse
443-
444436
self.include = include
445437
self.exclude = exclude
446438
self.search_space_updates = search_space_updates
447439

448-
self.X_train, self.y_train = self.datamanager.train_tensors
449-
450-
if self.datamanager.val_tensors is not None:
451-
self.X_valid, self.y_valid = self.datamanager.val_tensors
452-
else:
453-
self.X_valid, self.y_valid = None, None
454-
455-
if self.datamanager.test_tensors is not None:
456-
self.X_test, self.y_test = self.datamanager.test_tensors
457-
else:
458-
self.X_test, self.y_test = None, None
459-
460440
self.metric = metric
461441

462442
self.seed = seed
463443

444+
self._init_datamanager_info()
445+
464446
# Flag to save target for ensemble
465447
self.output_y_hat_optimization = output_y_hat_optimization
466448

@@ -497,12 +479,6 @@ def __init__(self, backend: Backend,
497479
else:
498480
raise ValueError('task {} not available'.format(self.task_type))
499481
self.predict_function = self._predict_proba
500-
self.dataset_properties = self.datamanager.get_dataset_properties(
501-
get_dataset_requirements(info=self.datamanager.get_required_dataset_info(),
502-
include=self.include,
503-
exclude=self.exclude,
504-
search_space_updates=self.search_space_updates
505-
))
506482

507483
self.additional_metrics: Optional[List[autoPyTorchMetric]] = None
508484
metrics_dict: Optional[Dict[str, List[str]]] = None
@@ -542,6 +518,53 @@ def __init__(self, backend: Backend,
542518
self.logger.debug("Fit dictionary in Abstract evaluator: {}".format(dict_repr(self.fit_dictionary)))
543519
self.logger.debug("Search space updates :{}".format(self.search_space_updates))
544520

521+
def _init_datamanager_info(
522+
self,
523+
) -> None:
524+
"""
525+
Initialises instance attributes that come from the datamanager.
526+
For example,
527+
X_train, y_train, etc.
528+
"""
529+
530+
datamanager: BaseDataset = self.backend.load_datamanager()
531+
532+
assert datamanager.task_type is not None, \
533+
"Expected dataset {} to have task_type got None".format(datamanager.__class__.__name__)
534+
self.task_type = STRING_TO_TASK_TYPES[datamanager.task_type]
535+
self.output_type = STRING_TO_OUTPUT_TYPES[datamanager.output_type]
536+
self.issparse = datamanager.issparse
537+
538+
self.X_train, self.y_train = datamanager.train_tensors
539+
540+
if datamanager.val_tensors is not None:
541+
self.X_valid, self.y_valid = datamanager.val_tensors
542+
else:
543+
self.X_valid, self.y_valid = None, None
544+
545+
if datamanager.test_tensors is not None:
546+
self.X_test, self.y_test = datamanager.test_tensors
547+
else:
548+
self.X_test, self.y_test = None, None
549+
550+
self.resampling_strategy = datamanager.resampling_strategy
551+
552+
self.num_classes: Optional[int] = getattr(datamanager, "num_classes", None)
553+
554+
self.dataset_properties = datamanager.get_dataset_properties(
555+
get_dataset_requirements(info=datamanager.get_required_dataset_info(),
556+
include=self.include,
557+
exclude=self.exclude,
558+
search_space_updates=self.search_space_updates
559+
))
560+
self.splits = datamanager.splits
561+
if self.splits is None:
562+
raise AttributeError(f"create_splits on {datamanager.__class__.__name__} must be called "
563+
f"before the instantiation of {self.__class__.__name__}")
564+
565+
# delete datamanager from memory
566+
del datamanager
567+
545568
def _init_fit_dictionary(
546569
self,
547570
logger_port: int,
@@ -988,21 +1011,20 @@ def _ensure_prediction_array_sizes(self, prediction: np.ndarray,
9881011
(np.ndarray):
9891012
The formatted prediction
9901013
"""
991-
assert self.datamanager.num_classes is not None, "Called function on wrong task"
992-
num_classes: int = self.datamanager.num_classes
1014+
assert self.num_classes is not None, "Called function on wrong task"
9931015

9941016
if self.output_type == MULTICLASS and \
995-
prediction.shape[1] < num_classes:
1017+
prediction.shape[1] < self.num_classes:
9961018
if Y_train is None:
9971019
raise ValueError('Y_train must not be None!')
9981020
classes = list(np.unique(Y_train))
9991021

10001022
mapping = dict()
1001-
for class_number in range(num_classes):
1023+
for class_number in range(self.num_classes):
10021024
if class_number in classes:
10031025
index = classes.index(class_number)
10041026
mapping[index] = class_number
1005-
new_predictions = np.zeros((prediction.shape[0], num_classes),
1027+
new_predictions = np.zeros((prediction.shape[0], self.num_classes),
10061028
dtype=np.float32)
10071029

10081030
for index in mapping:

autoPyTorch/evaluation/test_evaluator.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,17 +145,12 @@ def __init__(
145145
search_space_updates=search_space_updates
146146
)
147147

148-
if not isinstance(self.datamanager.resampling_strategy, (NoResamplingStrategyTypes)):
149-
resampling_strategy = self.datamanager.resampling_strategy
148+
if not isinstance(self.resampling_strategy, (NoResamplingStrategyTypes)):
150149
raise ValueError(
151150
f'resampling_strategy for TestEvaluator must be in '
152-
f'NoResamplingStrategyTypes, but got {resampling_strategy}'
151+
f'NoResamplingStrategyTypes, but got {self.resampling_strategy}'
153152
)
154153

155-
self.splits = self.datamanager.splits
156-
if self.splits is None:
157-
raise AttributeError("create_splits must be called in {}".format(self.datamanager.__class__.__name__))
158-
159154
def fit_predict_and_loss(self) -> None:
160155

161156
split_id = 0

autoPyTorch/evaluation/train_evaluator.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,12 @@ def __init__(self, backend: Backend, queue: Queue,
152152
search_space_updates=search_space_updates
153153
)
154154

155-
if not isinstance(self.datamanager.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
156-
resampling_strategy = self.datamanager.resampling_strategy
155+
if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
157156
raise ValueError(
158157
f'resampling_strategy for TrainEvaluator must be in '
159-
f'(CrossValTypes, HoldoutValTypes), but got {resampling_strategy}'
158+
f'(CrossValTypes, HoldoutValTypes), but got {self.resampling_strategy}'
160159
)
161160

162-
self.splits = self.datamanager.splits
163-
if self.splits is None:
164-
raise AttributeError("Must have called create_splits on {}".format(self.datamanager.__class__.__name__))
165161
self.num_folds: int = len(self.splits)
166162
self.Y_targets: List[Optional[np.ndarray]] = [None] * self.num_folds
167163
self.Y_train_targets: np.ndarray = np.ones(self.y_train.shape) * np.NaN

autoPyTorch/optimizer/smbo.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from smac.utils.io.traj_logging import TrajEntry
1919

2020
from autoPyTorch.automl_common.common.utils.backend import Backend
21-
from autoPyTorch.datasets.base_dataset import BaseDataset
2221
from autoPyTorch.datasets.resampling_strategy import (
2322
CrossValTypes,
2423
DEFAULT_RESAMPLING_PARAMETERS,
@@ -194,9 +193,8 @@ def __init__(self,
194193
super(AutoMLSMBO, self).__init__()
195194
# data related
196195
self.dataset_name = dataset_name
197-
self.datamanager: Optional[BaseDataset] = None
198196
self.metric = metric
199-
self.task: Optional[str] = None
197+
200198
self.backend = backend
201199
self.all_supported_metrics = all_supported_metrics
202200

@@ -252,21 +250,11 @@ def __init__(self,
252250
self.initial_configurations = initial_configurations \
253251
if len(initial_configurations) > 0 else None
254252

255-
def reset_data_manager(self) -> None:
256-
if self.datamanager is not None:
257-
del self.datamanager
258-
self.datamanager = self.backend.load_datamanager()
259-
260-
if self.datamanager is not None and self.datamanager.task_type is not None:
261-
self.task = self.datamanager.task_type
262-
263253
def run_smbo(self, func: Optional[Callable] = None
264254
) -> Tuple[RunHistory, List[TrajEntry], str]:
265255

266256
self.watcher.start_task('SMBO')
267257
self.logger.info("Started run of SMBO")
268-
# == first things first: load the datamanager
269-
self.reset_data_manager()
270258

271259
# == Initialize non-SMBO stuff
272260
# first create a scenario

0 commit comments

Comments
 (0)