@@ -433,34 +433,16 @@ def __init__(self, backend: Backend,
433
433
self .backend : Backend = backend
434
434
self .queue = queue
435
435
436
- self .datamanager : BaseDataset = self .backend .load_datamanager ()
437
-
438
- assert self .datamanager .task_type is not None , \
439
- "Expected dataset {} to have task_type got None" .format (self .datamanager .__class__ .__name__ )
440
- self .task_type = STRING_TO_TASK_TYPES [self .datamanager .task_type ]
441
- self .output_type = STRING_TO_OUTPUT_TYPES [self .datamanager .output_type ]
442
- self .issparse = self .datamanager .issparse
443
-
444
436
self .include = include
445
437
self .exclude = exclude
446
438
self .search_space_updates = search_space_updates
447
439
448
- self .X_train , self .y_train = self .datamanager .train_tensors
449
-
450
- if self .datamanager .val_tensors is not None :
451
- self .X_valid , self .y_valid = self .datamanager .val_tensors
452
- else :
453
- self .X_valid , self .y_valid = None , None
454
-
455
- if self .datamanager .test_tensors is not None :
456
- self .X_test , self .y_test = self .datamanager .test_tensors
457
- else :
458
- self .X_test , self .y_test = None , None
459
-
460
440
self .metric = metric
461
441
462
442
self .seed = seed
463
443
444
+ self ._init_datamanager_info ()
445
+
464
446
# Flag to save target for ensemble
465
447
self .output_y_hat_optimization = output_y_hat_optimization
466
448
@@ -497,12 +479,6 @@ def __init__(self, backend: Backend,
497
479
else :
498
480
raise ValueError ('task {} not available' .format (self .task_type ))
499
481
self .predict_function = self ._predict_proba
500
- self .dataset_properties = self .datamanager .get_dataset_properties (
501
- get_dataset_requirements (info = self .datamanager .get_required_dataset_info (),
502
- include = self .include ,
503
- exclude = self .exclude ,
504
- search_space_updates = self .search_space_updates
505
- ))
506
482
507
483
self .additional_metrics : Optional [List [autoPyTorchMetric ]] = None
508
484
metrics_dict : Optional [Dict [str , List [str ]]] = None
@@ -542,6 +518,53 @@ def __init__(self, backend: Backend,
542
518
self .logger .debug ("Fit dictionary in Abstract evaluator: {}" .format (dict_repr (self .fit_dictionary )))
543
519
self .logger .debug ("Search space updates :{}" .format (self .search_space_updates ))
544
520
521
+ def _init_datamanager_info (
522
+ self ,
523
+ ) -> None :
524
+ """
525
+ Initialises instance attributes that come from the datamanager.
526
+ For example,
527
+ X_train, y_train, etc.
528
+ """
529
+
530
+ datamanager : BaseDataset = self .backend .load_datamanager ()
531
+
532
+ assert datamanager .task_type is not None , \
533
+ "Expected dataset {} to have task_type got None" .format (datamanager .__class__ .__name__ )
534
+ self .task_type = STRING_TO_TASK_TYPES [datamanager .task_type ]
535
+ self .output_type = STRING_TO_OUTPUT_TYPES [datamanager .output_type ]
536
+ self .issparse = datamanager .issparse
537
+
538
+ self .X_train , self .y_train = datamanager .train_tensors
539
+
540
+ if datamanager .val_tensors is not None :
541
+ self .X_valid , self .y_valid = datamanager .val_tensors
542
+ else :
543
+ self .X_valid , self .y_valid = None , None
544
+
545
+ if datamanager .test_tensors is not None :
546
+ self .X_test , self .y_test = datamanager .test_tensors
547
+ else :
548
+ self .X_test , self .y_test = None , None
549
+
550
+ self .resampling_strategy = datamanager .resampling_strategy
551
+
552
+ self .num_classes : Optional [int ] = getattr (datamanager , "num_classes" , None )
553
+
554
+ self .dataset_properties = datamanager .get_dataset_properties (
555
+ get_dataset_requirements (info = datamanager .get_required_dataset_info (),
556
+ include = self .include ,
557
+ exclude = self .exclude ,
558
+ search_space_updates = self .search_space_updates
559
+ ))
560
+ self .splits = datamanager .splits
561
+ if self .splits is None :
562
+ raise AttributeError (f"create_splits on { datamanager .__class__ .__name__ } must be called "
563
+ f"before the instantiation of { self .__class__ .__name__ } " )
564
+
565
+ # delete datamanager from memory
566
+ del datamanager
567
+
545
568
def _init_fit_dictionary (
546
569
self ,
547
570
logger_port : int ,
@@ -988,21 +1011,20 @@ def _ensure_prediction_array_sizes(self, prediction: np.ndarray,
988
1011
(np.ndarray):
989
1012
The formatted prediction
990
1013
"""
991
- assert self .datamanager .num_classes is not None , "Called function on wrong task"
992
- num_classes : int = self .datamanager .num_classes
1014
+ assert self .num_classes is not None , "Called function on wrong task"
993
1015
994
1016
if self .output_type == MULTICLASS and \
995
- prediction .shape [1 ] < num_classes :
1017
+ prediction .shape [1 ] < self . num_classes :
996
1018
if Y_train is None :
997
1019
raise ValueError ('Y_train must not be None!' )
998
1020
classes = list (np .unique (Y_train ))
999
1021
1000
1022
mapping = dict ()
1001
- for class_number in range (num_classes ):
1023
+ for class_number in range (self . num_classes ):
1002
1024
if class_number in classes :
1003
1025
index = classes .index (class_number )
1004
1026
mapping [index ] = class_number
1005
- new_predictions = np .zeros ((prediction .shape [0 ], num_classes ),
1027
+ new_predictions = np .zeros ((prediction .shape [0 ], self . num_classes ),
1006
1028
dtype = np .float32 )
1007
1029
1008
1030
for index in mapping :
0 commit comments