diff --git a/src/Microsoft.ML.AutoML/Tuner/SmacTuner.cs b/src/Microsoft.ML.AutoML/Tuner/SmacTuner.cs index 43a427c52d..3cf255ff2b 100644 --- a/src/Microsoft.ML.AutoML/Tuner/SmacTuner.cs +++ b/src/Microsoft.ML.AutoML/Tuner/SmacTuner.cs @@ -111,6 +111,12 @@ public Parameter Propose(TrialSettings settings) } } + // test purpose + internal Queue Candidates => _candidates; + + // test purpose + internal List Histories => _histories; + private FastForestRegressionModelParameters FitModel(IEnumerable history) { Single[] losses = new Single[history.Count()]; @@ -357,7 +363,10 @@ private double ComputeEI(double bestLoss, double[] forestStatistics) public void Update(TrialResult result) { - _histories.Add(result); + if (!double.IsNaN(result.Loss) && !double.IsInfinity(result.Loss)) + { + _histories.Add(result); + } } } } diff --git a/test/Microsoft.ML.AutoML.Tests/TunerTests.cs b/test/Microsoft.ML.AutoML.Tests/TunerTests.cs index 24ad39ab83..6775e61066 100644 --- a/test/Microsoft.ML.AutoML.Tests/TunerTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/TunerTests.cs @@ -66,6 +66,39 @@ public void tuner_e2e_test() } } + [Fact] + public void Smac_should_ignore_fail_trials_during_initialize() + { + // fix for https://github.com/dotnet/machinelearning-modelbuilder/issues/2721 + var context = new MLContext(1); + var searchSpace = new SearchSpace(); + var tuner = new SmacTuner(context, searchSpace, seed: 1); + for (int i = 0; i != 1000; ++i) + { + var trialSettings = new TrialSettings() + { + TrialId = i, + }; + + var param = tuner.Propose(trialSettings); + trialSettings.Parameter = param; + var option = param.AsType(); + + option.L1Regularization.Should().BeInRange(0.03125f, 32768.0f); + option.L2Regularization.Should().BeInRange(0.03125f, 32768.0f); + + tuner.Update(new TrialResult() + { + DurationInMilliseconds = i * 1000, + Loss = double.NaN, + TrialSettings = trialSettings, + }); + } + + tuner.Candidates.Count.Should().Be(0); + tuner.Histories.Count.Should().Be(0); + } + [Fact] public void CFO_should_be_recoverd_if_history_provided() {