@@ -1,6 +1,6 @@
from __future__ import annotations

- from typing import Any, Dict, Iterable, Sequence, Type, cast
+ from typing import Any, Iterable, Mapping, Sequence, Type, cast

import logging.handlers
import multiprocessing
@@ -46,7 +46,7 @@ def __init__(
task_type: int,
metrics: Sequence[Scorer],
ensemble_class: Type[AbstractEnsemble] = EnsembleSelection,
- ensemble_kwargs: Dict[str, Any] | None = None,
+ ensemble_kwargs: Mapping[str, Any] | None = None,
ensemble_nbest: int | float = 50,
max_models_on_disc: int | float | None = 100,
seed: int = 1,
@@ -71,9 +71,11 @@ def __init__(
metrics: Sequence[Scorer]
    Metrics to optimize the ensemble for. These must be non-duplicated.

- ensemble_class
+ ensemble_class: Type[AbstractEnsemble]
+     Implementation of the ensemble algorithm.

- ensemble_kwargs
+ ensemble_kwargs: Mapping[str, Any] | None
+     Arguments passed to the constructor of the ensemble algorithm.

ensemble_nbest: int | float = 50
@@ -169,6 +171,8 @@ def __init__(
self.validation_performance_ = np.inf

# Data we may need
+ # TODO: The test data is needlessly loaded but automl_common has no concept of
+ # these and is perhaps too rigid
datamanager: XYDataManager = self.backend.load_datamanager()
self._X_test: SUPPORTED_FEAT_TYPES | None = datamanager.data.get("X_test", None)
self._y_test: np.ndarray | None = datamanager.data.get("Y_test", None)
@@ -442,6 +446,17 @@ def main(
self.logger.debug("Found no runs")
raise RuntimeError("Found no runs")

+ # We load in `X_data` if we need it
+ if any(m._needs_X for m in self.metrics):
+     ensemble_X_data = self.X_data("ensemble")
+
+     if ensemble_X_data is None:
+         msg = "No `X_data` for 'ensemble' which was required by metrics"
+         self.logger.debug(msg)
+         raise RuntimeError(msg)
+ else:
+     ensemble_X_data = None
+
# Calculate the loss for those that require it
requires_update = self.requires_loss_update(runs)
if self.read_at_most is not None:
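The block added here only loads `X_data` when at least one metric actually needs it, and fails loudly if that data is missing. A minimal, runnable sketch of the same gating pattern, where `SimpleScorer` and the `loader` callable are illustrative stand-ins rather than the project's `Scorer` or backend:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Sequence


@dataclass
class SimpleScorer:
    """Illustrative stand-in for a metric that may require feature data."""

    name: str
    _needs_X: bool = False


def load_X_if_needed(
    metrics: Sequence[SimpleScorer],
    loader: Callable[[str], Any | None],
    kind: str = "ensemble",
) -> Any | None:
    """Load X once, and only when some metric declares that it needs it."""
    if any(m._needs_X for m in metrics):
        X = loader(kind)
        if X is None:
            raise RuntimeError(f"No `X_data` for '{kind}' which was required by metrics")
        return X
    return None


# No metric needs X, so the (here always-empty) loader is never consulted.
metrics = [SimpleScorer("accuracy"), SimpleScorer("log_loss")]
assert load_X_if_needed(metrics, loader=lambda kind: None) is None
```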
@@ -450,9 +465,7 @@ def main(
for run in requires_update:
    run.record_modified_times()  # So we don't count as modified next time
    run.losses = {
-         metric.name: self.loss(
-             run, metric=metric, X_data=self.X_data("ensemble")
-         )
+         metric.name: self.loss(run, metric=metric, X_data=ensemble_X_data)
        for metric in self.metrics
    }

@@ -549,15 +562,14 @@ def main(
    return self.ensemble_history, self.ensemble_nbest

targets = cast(np.ndarray, self.targets("ensemble"))  # Sure they exist
- X_data = self.X_data("ensemble")

ensemble = self.fit_ensemble(
    candidates=candidates,
-     X_data=X_data,
    targets=targets,
    runs=runs,
    ensemble_class=self.ensemble_class,
    ensemble_kwargs=self.ensemble_kwargs,
+     X_data=ensemble_X_data,
    task=self.task_type,
    metrics=self.metrics,
    precision=self.precision,
@@ -587,7 +599,15 @@ def main(

run_preds = [r.predictions(kind, precision=self.precision) for r in models]
pred = ensemble.predict(run_preds)
- X_data = self.X_data(kind)
+
+ if any(m._needs_X for m in self.metrics):
+     X_data = self.X_data(kind)
+     if X_data is None:
+         msg = f"No `X` data for '{kind}' which was required by metrics"
+         self.logger.debug(msg)
+         raise RuntimeError(msg)
+ else:
+     X_data = None

scores = calculate_scores(
    solution=pred_targets,
@@ -597,10 +617,19 @@ def main(
    X_data=X_data,
    scoring_functions=None,
)
+
+ # TODO only one metric in history
+ #
+ # We should probably return for all metrics but this makes
+ # automl::performance_history a lot more complicated, will
+ # tackle in a future PR
+ first_metric = self.metrics[0]
performance_stamp[f"ensemble_{score_name}_score"] = scores[
-     self.metrics[0].name
+     first_metric.name
]
- self.ensemble_history.append(performance_stamp)
+
+ # Add the performance stamp to the history
+ self.ensemble_history.append(performance_stamp)

# Lastly, delete any runs that need to be deleted. We save this as the last step
# so that we have an ensemble saved that is up to date. If we do not do so,
@@ -805,13 +834,13 @@ def candidate_selection(

def fit_ensemble(
    self,
-     candidates: list[Run],
-     X_data: SUPPORTED_FEAT_TYPES,
-     targets: np.ndarray,
+     candidates: Sequence[Run],
+     runs: Sequence[Run],
    *,
-     runs: list[Run],
+     targets: np.ndarray | None = None,
    ensemble_class: Type[AbstractEnsemble] = EnsembleSelection,
-     ensemble_kwargs: Dict[str, Any] | None = None,
+     ensemble_kwargs: Mapping[str, Any] | None = None,
+     X_data: SUPPORTED_FEAT_TYPES | None = None,
    task: int | None = None,
    metrics: Sequence[Scorer] | None = None,
    precision: int | None = None,
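After this reordering, `candidates` and `runs` are the only positional parameters; everything from `targets` on is keyword-only and optional, falling back to the builder's stored configuration. A runnable toy mirroring just that signature shape, where `fit_ensemble_sketch` and its argument values are hypothetical stand-ins and not the real method:

```python
from __future__ import annotations

from typing import Any, Mapping, Sequence


def fit_ensemble_sketch(
    candidates: Sequence[str],
    runs: Sequence[str],
    *,
    targets: Any | None = None,
    ensemble_kwargs: Mapping[str, Any] | None = None,
    X_data: Any | None = None,
) -> dict[str, Any]:
    """Toy stand-in mirroring the keyword-only layout of the new signature."""
    return {
        "n_candidates": len(candidates),
        "n_runs": len(runs),
        "targets_provided": targets is not None,
        "kwargs": dict(ensemble_kwargs or {}),
    }


# Only the two run lists are positional; everything else must be named.
summary = fit_ensemble_sketch(
    ["run_1"],
    ["run_1", "dummy"],
    ensemble_kwargs={"size": 5},
)
assert summary["n_runs"] == 2 and summary["kwargs"] == {"size": 5}
```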
@@ -825,24 +854,24 @@ def fit_ensemble(

Parameters
----------
- candidates: list[Run]
+ candidates: Sequence[Run]
    List of runs to build an ensemble from

- X_data: SUPPORTED_FEAT_TYPES
-     The base level data.
+ runs: Sequence[Run]
+     List of all runs (also pruned ones and dummy runs)

- targets: np.ndarray
+ targets: np.ndarray | None = None
    The targets to build the ensemble with

- runs: list[Run]
-     List of all runs (also pruned ones and dummy runs)
-
- ensemble_class: AbstractEnsemble
+ ensemble_class: Type[AbstractEnsemble]
    Implementation of the ensemble algorithm.

- ensemble_kwargs: Dict[str, Any]
+ ensemble_kwargs: Mapping[str, Any] | None
    Arguments passed to the constructor of the ensemble algorithm.

+ X_data: SUPPORTED_FEAT_TYPES | None = None
+     The base level data.
+
task: int | None = None
    The kind of task performed

@@ -859,24 +888,42 @@ def fit_ensemble(
-------
AbstractEnsemble
"""
- task = task if task is not None else self.task_type
+ # Validate we have targets if None specified
+ if targets is None:
+     targets = self.targets("ensemble")
+     if targets is None:
+         path = self.backend._get_targets_ensemble_filename()
+         raise ValueError(f"`fit_ensemble` could not find any targets at {path}")
+
ensemble_class = (
    ensemble_class if ensemble_class is not None else self.ensemble_class
)
- ensemble_kwargs = (
-     ensemble_kwargs if ensemble_kwargs is not None else self.ensemble_kwargs
- )
- ensemble_kwargs = ensemble_kwargs if ensemble_kwargs is not None else {}
- metrics = metrics if metrics is not None else self.metrics
- rs = random_state if random_state is not None else self.random_state

- ensemble = ensemble_class(
-     task_type=task,
-     metrics=metrics,
-     random_state=rs,
-     backend=self.backend,
-     **ensemble_kwargs,
- )  # type: AbstractEnsemble
+ # Create the ensemble_kwargs, favouring in order:
+ # 1) function kwargs, 2) function params 3) init_kwargs 4) init_params
+
+ # Collect func params in dict if they're not None
+ params = {
+     k: v
+     for k, v in [
+         ("task_type", task),
+         ("metrics", metrics),
+         ("random_state", random_state),
+     ]
+     if v is not None
+ }
+
+ kwargs = {
+     "backend": self.backend,
+     "task_type": self.task_type,
+     "metrics": self.metrics,
+     "random_state": self.random_state,
+     **(self.ensemble_kwargs or {}),
+     **params,
+     **(ensemble_kwargs or {}),
+ }
+
+ ensemble = ensemble_class(**kwargs)  # type: AbstractEnsemble

self.logger.debug(f"Fitting ensemble on {len(candidates)} models")
start_time = time.time()
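The merge above relies on dict-literal semantics: later entries override earlier ones, so listing the builder's defaults first and the per-call kwargs last yields exactly the precedence the comment describes (function kwargs over function params over init kwargs over init params). A small runnable illustration with made-up values; the keys mirror the real ones but the numbers are arbitrary:

```python
# Later entries in a dict literal win, so defaults go first and
# the explicit per-call kwargs go last.
init_params = {"task_type": 1, "metrics": ["accuracy"], "random_state": 0}
init_kwargs = {"ensemble_size": 25}   # set once at construction time
func_params = {"task_type": 2}        # non-None function arguments
func_kwargs = {"ensemble_size": 10}   # explicit per-call overrides

kwargs = {**init_params, **init_kwargs, **func_params, **func_kwargs}
assert kwargs == {
    "task_type": 2,           # function param beat the init param
    "metrics": ["accuracy"],  # untouched default
    "random_state": 0,
    "ensemble_size": 10,      # function kwarg beat the init kwarg
}
```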
@@ -995,7 +1042,8 @@ def loss(
    self,
    run: Run,
    metric: Scorer,
-     X_data: SUPPORTED_FEAT_TYPES,
+     *,
+     X_data: SUPPORTED_FEAT_TYPES | None = None,
    kind: str = "ensemble",
) -> float:
    """Calculate the loss for a run
@@ -1008,6 +1056,9 @@ def loss(
metric: Scorer
    The metric to calculate the loss of

+ X_data: SUPPORTED_FEAT_TYPES | None = None
+     Any X_data required to be passed to the metric
+
kind: str = "ensemble"
    The kind of targets to use for the run

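Making `X_data` keyword-only with a `None` default lets callers whose metrics need no feature data skip it entirely. A runnable toy sketch of that calling convention, where `loss_sketch` and `error_rate` are hypothetical stand-ins and not the builder's API:

```python
from __future__ import annotations

from typing import Any, Callable


def loss_sketch(
    y_true: list[int],
    y_pred: list[int],
    metric: Callable[..., float],
    *,
    X_data: Any | None = None,  # keyword-only, forwarded only when provided
) -> float:
    """Toy stand-in: evaluate the metric, forwarding X_data only if given."""
    if X_data is not None:
        return metric(y_true, y_pred, X_data=X_data)
    return metric(y_true, y_pred)


def error_rate(y: list[int], p: list[int]) -> float:
    """A metric that needs no feature data at all."""
    return sum(a != b for a, b in zip(y, p)) / len(y)


print(loss_sketch([0, 1, 1], [0, 1, 0], error_rate))  # 0.333...
# loss_sketch([0, 1], [0, 1], error_rate, None) would raise TypeError:
# X_data can only be passed by keyword.
```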