Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Timer and ctx.CancelExecution() to fix AutoML max-time experiment bug #5445

Merged
merged 29 commits into from
Nov 3, 2020
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d0f7054
Use ctx.CalncelExecution() to fix AutoML max-time experiment bug
mstfbl Oct 21, 2020
4fa26f8
Added unit test for checking canceled experiment
mstfbl Oct 21, 2020
48a6267
Nit fix
mstfbl Oct 21, 2020
f324030
Different run time on Linux
mstfbl Oct 22, 2020
ee70024
Review
mstfbl Oct 22, 2020
36bf24e
Testing four ouput
mstfbl Oct 22, 2020
d5d23de
Used reflection to test for contexts being canceled
mstfbl Oct 23, 2020
33cf5a6
Reviews
mstfbl Oct 26, 2020
bfc93e9
Merge remote-tracking branch 'upstream/master' into issue5437
mstfbl Oct 26, 2020
c69a19f
Reviews
mstfbl Oct 28, 2020
299b05b
Added main MLContext listener-timer
mstfbl Oct 29, 2020
2e2d441
Merge remote-tracking branch 'upstream/master' into issue5437
mstfbl Oct 29, 2020
ce747fb
Added PRNG on _context, held onto timers for avoiding GC
mstfbl Oct 30, 2020
7635500
Addressed reviews
mstfbl Oct 30, 2020
94a80de
Unit test edits
mstfbl Oct 30, 2020
abe1d7f
Increase run time of experiment to guarantee probabilities
mstfbl Oct 30, 2020
9585a50
Edited unit test to check produced schema of next run model's predict…
mstfbl Oct 30, 2020
1ab662f
Remove scheme check as different CI builds result in varying schemas
mstfbl Oct 30, 2020
bc9e578
Decrease max experiment time unit test time
mstfbl Oct 30, 2020
71ebf23
Merged with master
mstfbl Oct 31, 2020
2d8d06f
Added Timers
mstfbl Nov 2, 2020
490d8c1
Increase second timer time, edit unit test
mstfbl Nov 2, 2020
b0de1d3
Added try catch for OperationCanceledException in Execute()
mstfbl Nov 3, 2020
0918afa
Add AggregateException try catch to slow unit tests for parallel testing
mstfbl Nov 3, 2020
0922aed
Reviews
mstfbl Nov 3, 2020
ef4b34f
Final reviews
mstfbl Nov 3, 2020
b4b49ce
Added LightGBMFact to binary classification test
mstfbl Nov 3, 2020
6502fc8
Removed extra Operation Stopped exception try catch
mstfbl Nov 3, 2020
28e2f2e
Add back OperationCanceledException to Experiment.cs
mstfbl Nov 3, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 65 additions & 6 deletions src/Microsoft.ML.AutoML/Experiment/Experiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;

namespace Microsoft.ML.AutoML
Expand All @@ -25,6 +27,11 @@ internal class Experiment<TRunDetail, TMetrics> where TRunDetail : RunDetail
private readonly IRunner<TRunDetail> _runner;
private readonly IList<SuggestedPipelineRunDetail> _history;
private readonly IChannel _logger;
private Timer _maxExperimentTimeTimer;
private Timer _mainContextCanceledTimer;
private bool _experimentTimerExpired;
private MLContext _currentModelMLContext;
private Random _newContextSeedGenerator;

public Experiment(MLContext context,
TaskKind task,
Expand All @@ -49,23 +56,76 @@ public Experiment(MLContext context,
_datasetColumnInfo = datasetColumnInfo;
_runner = runner;
_logger = logger;
_experimentTimerExpired = false;
}

private void MaxExperimentTimeExpiredEvent(object state)
{
// If at least one model was run, end experiment immediately.
// Else, wait for first model to run before experiment is concluded.
_experimentTimerExpired = true;
if (_history.Any(r => r.RunSucceeded))
{
_logger.Warning("Allocated time for Experiment of {0} seconds has elapsed with {1} models run. Ending experiment...",
_experimentSettings.MaxExperimentTimeInSeconds, _history.Count());
_currentModelMLContext.CancelExecution();
}
}

private void MainContextCanceledEvent(object state)
{
// If the main MLContext is canceled, cancel the ongoing model training and MLContext.
if ((_context.Model.GetEnvironment() as ICancelable).IsCanceled)
{
_logger.Warning("Main MLContext has been canceled. Ending experiment...");
mstfbl marked this conversation as resolved.
Show resolved Hide resolved
_currentModelMLContext.CancelExecution();
}
}

public IList<TRunDetail> Execute()
{
var stopwatch = Stopwatch.StartNew();
var iterationResults = new List<TRunDetail>();
// Create a timer for the max duration of experiment. When given time has
// elapsed, MaxExperimentTimeExpiredEvent is called to interrupt training
// of current model. Timer is not used if no experiment time is given, or
// is not a positive number.
if (_experimentSettings.MaxExperimentTimeInSeconds > 0)
{
_maxExperimentTimeTimer = new Timer(
new TimerCallback(MaxExperimentTimeExpiredEvent), null,
_experimentSettings.MaxExperimentTimeInSeconds * 1000, Timeout.Infinite
);
}
mstfbl marked this conversation as resolved.
Show resolved Hide resolved
// If given max duration of experiment is 0, only 1 model will be trained.
// _experimentSettings.MaxExperimentTimeInSeconds is of type uint, it is
// either 0 or >0.
else
_experimentTimerExpired = true;

// Add second timer to check for the cancelation signal from the main MLContext
mstfbl marked this conversation as resolved.
Show resolved Hide resolved
// to the active child MLContext. This timer will propagate the cancelation
// signal from the main to the child MLContexs if the main MLContext is
// canceled.
_mainContextCanceledTimer = new Timer(new TimerCallback(MainContextCanceledEvent), null, 1000, 1000);

// Pseudo random number generator to result in deterministic runs with the provided main MLContext's seed and to
// maintain variability between training iterations.
int? mainContextSeed = ((ISeededEnvironment)_context.Model.GetEnvironment()).Seed;
_newContextSeedGenerator = (mainContextSeed.HasValue) ? RandomUtils.Create(mainContextSeed.Value) : RandomUtils.Create();
mstfbl marked this conversation as resolved.
Show resolved Hide resolved

do
{
var iterationStopwatch = Stopwatch.StartNew();

// get next pipeline
var getPipelineStopwatch = Stopwatch.StartNew();
var pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, _datasetColumnInfo, _task,
_optimizingMetricInfo.IsMaximizing, _experimentSettings.CacheBeforeTrainer, _trainerAllowList);

var pipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds;
// A new MLContext is needed per model run. When max experiment time is reached, each used
// context is canceled to stop further model training. The cancellation of the main MLContext
// a user has instantiated is not desirable, thus additional MLContexts are used.
_currentModelMLContext = _newContextSeedGenerator == null ? new MLContext() : new MLContext(_newContextSeedGenerator.Next());
var pipeline = PipelineSuggester.GetNextInferredPipeline(_currentModelMLContext, _history, _datasetColumnInfo, _task,
_optimizingMetricInfo.IsMaximizing, _experimentSettings.CacheBeforeTrainer, _trainerAllowList);

// break if no candidates returned, means no valid pipeline available
if (pipeline == null)
Expand Down Expand Up @@ -101,8 +161,7 @@ public IList<TRunDetail> Execute()

} while (_history.Count < _experimentSettings.MaxModels &&
!_experimentSettings.CancellationToken.IsCancellationRequested &&
stopwatch.Elapsed.TotalSeconds < _experimentSettings.MaxExperimentTimeInSeconds);

!_experimentTimerExpired);
return iterationResults;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public CrossValSummaryRunner(MLContext context,
for (var i = 0; i < _trainDatasets.Length; i++)
{
var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1);
var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i],
var trainResult = RunnerUtil.TrainAndScorePipeline(pipeline.GetContext(), pipeline, _trainDatasets[i], _validDatasets[i],
_groupIdColumn, _labelColumn, _metricsAgent, _preprocessorTransforms?.ElementAt(i), modelFileInfo, _modelInputSchema,
_logger);
trainResults.Add(trainResult);
Expand Down
5 changes: 5 additions & 0 deletions src/Microsoft.ML.AutoML/Experiment/SuggestedPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ public override int GetHashCode()
return ToString().GetHashCode();
}

public MLContext GetContext()
{
return _context;
}

public Pipeline ToPipeline()
{
var pipelineElements = new List<PipelineNode>();
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/IHostEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
internal interface ICancelable
{
/// <summary>
/// Signal to stop exection in all the hosts.
/// Signal to stop execution in all the hosts.
/// </summary>
void CancelExecution();

Expand Down
39 changes: 35 additions & 4 deletions test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Linq;
using System.Reflection;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.TestFramework;
using Microsoft.ML.TestFramework.Attributes;
using Microsoft.ML.TestFrameworkCommon;
using Microsoft.ML.Trainers.LightGbm;
using Xunit;
using Xunit.Abstractions;
using static Microsoft.ML.DataOperationsCatalog;
Expand Down Expand Up @@ -117,7 +119,7 @@ public void AutoFitRegressionTest()
.Execute(trainData, validationData,
new ColumnInformation() { LabelColumnName = DatasetUtil.MlNetGeneratedRegressionLabel });

Assert.True(result.RunDetails.Max(i => i.ValidationMetrics.RSquared > 0.9));
Assert.True(result.RunDetails.Max(i => i?.ValidationMetrics?.RSquared) > 0.9);
}

[LightGBMFact]
Expand Down Expand Up @@ -165,7 +167,7 @@ public void AutoFitRankingTest()
Assert.True(experimentResults[i].RunDetails.Count() > 0);
Assert.NotNull(bestRun.ValidationMetrics);
Assert.True(bestRun.ValidationMetrics.NormalizedDiscountedCumulativeGains.Last() > 0.4);
Assert.True(bestRun.ValidationMetrics.DiscountedCumulativeGains.Last() > 20);
Assert.True(bestRun.ValidationMetrics.DiscountedCumulativeGains.Last() > 19);
var outputSchema = bestRun.Model.GetOutputSchema(trainDataView.Schema);
var expectedOutputNames = new string[] { labelColumnName, groupIdColumnName, groupIdColumnName, featuresColumnVectorNameA, featuresColumnVectorNameB,
"Features", scoreColumnName };
Expand Down Expand Up @@ -246,7 +248,7 @@ public void AutoFitRecommendationTest()
RunDetail<RegressionMetrics> bestRun = experimentResult.BestRun;
Assert.True(experimentResult.RunDetails.Count() > 1);
Assert.NotNull(bestRun.ValidationMetrics);
Assert.True(experimentResult.RunDetails.Max(i => i.ValidationMetrics.RSquared != 0));
Assert.True(experimentResult.RunDetails.Max(i => i?.ValidationMetrics?.RSquared* i?.ValidationMetrics?.RSquared) > 0.5);

var outputSchema = bestRun.Model.GetOutputSchema(trainDataView.Schema);
var expectedOutputNames = new string[] { labelColumnName, userColumnName, userColumnName, itemColumnName, itemColumnName, scoreColumnName };
Expand Down Expand Up @@ -320,6 +322,35 @@ public void AutoFitWithPresplittedData()

}

[Fact]
public void AutoFitMaxExperimentTimeTest()
{
// A single binary classification experiment takes less than 5 seconds.
// System.OperationCanceledException is thrown when ongoing experiment
// is canceled and at least one model has been generated.
var context = new MLContext(1);
var dataPath = DatasetUtil.GetUciAdultDataset();
var columnInference = context.Auto().InferColumns(dataPath, DatasetUtil.UciAdultLabel);
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions);
var trainData = textLoader.Load(dataPath);
var experiment = context.Auto()
.CreateBinaryClassificationExperiment(10)
.Execute(trainData, new ColumnInformation() { LabelColumnName = DatasetUtil.UciAdultLabel });

// Ensure the (last) model that was training when maximum experiment time was reached has been stopped,
// and that its MLContext has been canceled.
Assert.True(experiment.RunDetails.Last().Exception.Message.Contains("Operation was canceled"),
"Training process was not successfully canceled after maximum experiment time was reached.");

// Ensure that the best found model can still run after maximum experiment time was reached.
var refitModel = experiment.BestRun.Estimator.Fit(trainData);
IDataView predictions = refitModel.Transform(trainData);
var prev = predictions.Preview();
Assert.Equal(30, predictions.Schema.Count);
Assert.True(predictions.Schema.GetColumnOrNull("PredictedLabel").HasValue);
Assert.True(predictions.Schema.GetColumnOrNull("Score").HasValue);
}

private TextLoader.Options GetLoaderArgs(string labelColumnName, string userIdColumnName, string itemIdColumnName)
{
return new TextLoader.Options()
Expand Down