-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve RegressionExperiment using AutoMLExperiment #6338
Merged
LittleLittleCloud
merged 4 commits into
dotnet:main
from
LittleLittleCloud:u/xiaoyun/autoRegression
Sep 29, 2022
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,8 +4,13 @@ | |
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Diagnostics; | ||
using System.Linq; | ||
using System.Threading.Tasks; | ||
using System.Threading; | ||
using Microsoft.Extensions.DependencyInjection; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Trainers; | ||
using Microsoft.ML.Trainers.FastTree; | ||
using Microsoft.ML.Trainers.LightGbm; | ||
|
@@ -92,16 +97,6 @@ public enum RegressionTrainer | |
/// </summary> | ||
LightGbm, | ||
|
||
/// <summary> | ||
/// See <see cref="OnlineGradientDescentTrainer"/>. | ||
/// </summary> | ||
OnlineGradientDescent, | ||
|
||
/// <summary> | ||
/// See <see cref="OlsTrainer"/>. | ||
/// </summary> | ||
Ols, | ||
|
||
/// <summary> | ||
/// See <see cref="LbfgsPoissonRegressionTrainer"/>. | ||
/// </summary> | ||
|
@@ -124,6 +119,10 @@ public enum RegressionTrainer | |
/// </example> | ||
public sealed class RegressionExperiment : ExperimentBase<RegressionMetrics, RegressionExperimentSettings> | ||
{ | ||
private readonly AutoMLExperiment _experiment; | ||
private const string Features = "__Features__"; | ||
private SweepablePipeline _pipeline; | ||
|
||
internal RegressionExperiment(MLContext context, RegressionExperimentSettings settings) | ||
: base(context, | ||
new RegressionMetricsAgent(context, settings.OptimizingMetric), | ||
|
@@ -132,6 +131,187 @@ internal RegressionExperiment(MLContext context, RegressionExperimentSettings se | |
TaskKind.Regression, | ||
TrainerExtensionUtil.GetTrainerNames(settings.Trainers)) | ||
{ | ||
_experiment = context.Auto().CreateExperiment(); | ||
|
||
if (settings.MaximumMemoryUsageInMegaByte is double d) | ||
{ | ||
_experiment.SetMaximumMemoryUsageInMegaByte(d); | ||
} | ||
|
||
_experiment.SetTrainingTimeInSeconds(Settings.MaxExperimentTimeInSeconds); | ||
} | ||
|
||
/// <summary>
/// Runs the regression experiment on a single dataset.
/// Datasets whose counted rows fall below 15000 are evaluated with 10-fold cross validation;
/// anything else is split with <c>Context.Data.TrainTestSplit</c> and forwarded to the
/// train / validation overload.
/// </summary>
/// <param name="trainData">Data used for training (and, depending on size, for validation).</param>
/// <param name="columnInformation">Column roles; its <c>LabelColumnName</c> is used as the regression label.</param>
/// <param name="preFeaturizer">Optional estimator that is placed in front of the generated pipeline.</param>
/// <param name="progressHandler">Optional callback that receives a run detail after every completed trial.</param>
/// <returns>All run details collected by the monitor, together with the best run.</returns>
public override ExperimentResult<RegressionMetrics> Execute(IDataView trainData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<RegressionMetrics>> progressHandler = null)
{
    var label = columnInformation.LabelColumnName;
    _experiment.SetRegressionMetric(Settings.OptimizingMetric, label);

    // Cross val threshold for # of dataset rows --
    // If dataset has < threshold # of rows, use cross val.
    // Else, run experiment using train-validate split.
    const int crossValRowCountThreshold = 15000;
    var rowCount = DatasetDimensionsUtil.CountRows(trainData, crossValRowCountThreshold);

    // TODO: split cross validation result according to sample key as well.
    if (rowCount < crossValRowCountThreshold)
    {
        const int numCrossValFolds = 10;
        _experiment.SetDataset(trainData, numCrossValFolds);

        _pipeline = CreateRegressionPipeline(trainData, columnInformation, preFeaturizer);
        // Register the pipeline with the experiment, exactly as the other Execute overloads do.
        // Without this call the experiment is run without the pipeline that was just built.
        _experiment.SetPipeline(_pipeline);

        // The monitor is created lazily by the experiment; capture it so results can be read after Run().
        TrialResultMonitor<RegressionMetrics> monitor = null;
        _experiment.SetMonitor((provider) =>
        {
            var channel = provider.GetService<IChannel>();
            var pipeline = provider.GetService<SweepablePipeline>();
            monitor = new TrialResultMonitor<RegressionMetrics>(channel, pipeline);
            monitor.OnTrialCompleted += (o, e) =>
            {
                var detail = BestResultUtil.ToRunDetail(Context, e, _pipeline);
                progressHandler?.Report(detail);
            };

            return monitor;
        });

        _experiment.SetTrialRunner<RegressionTrialRunner>();
        _experiment.Run();

        var runDetails = monitor.RunDetails.Select(e => BestResultUtil.ToRunDetail(Context, e, _pipeline));
        var bestRun = BestResultUtil.ToRunDetail(Context, monitor.BestRun, _pipeline);
        var result = new ExperimentResult<RegressionMetrics>(runDetails, bestRun);

        return result;
    }
    else
    {
        var splitData = Context.Data.TrainTestSplit(trainData);
        return Execute(splitData.TrainSet, splitData.TestSet, columnInformation, preFeaturizer, progressHandler);
    }
}
|
||
/// <summary>
/// Runs the regression experiment with an explicit training set and validation set.
/// </summary>
/// <param name="trainData">Data the candidate pipelines are fitted on.</param>
/// <param name="validationData">Data handed to the experiment as the validation set.</param>
/// <param name="columnInformation">Column roles; its <c>LabelColumnName</c> is used as the regression label.</param>
/// <param name="preFeaturizer">Optional estimator that is placed in front of the generated pipeline.</param>
/// <param name="progressHandler">Optional callback that receives a run detail after every completed trial.</param>
/// <returns>All run details collected by the monitor, together with the best run.</returns>
public override ExperimentResult<RegressionMetrics> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<RegressionMetrics>> progressHandler = null)
{
    // Configure metric, data and pipeline on the underlying AutoML experiment.
    _experiment.SetRegressionMetric(Settings.OptimizingMetric, columnInformation.LabelColumnName);
    _experiment.SetDataset(trainData, validationData);

    _pipeline = CreateRegressionPipeline(trainData, columnInformation, preFeaturizer);
    _experiment.SetPipeline(_pipeline);

    // The experiment builds the monitor through this factory; keep a reference for reading results later.
    TrialResultMonitor<RegressionMetrics> trialMonitor = null;
    _experiment.SetMonitor(services =>
    {
        trialMonitor = new TrialResultMonitor<RegressionMetrics>(
            services.GetService<IChannel>(),
            services.GetService<SweepablePipeline>());

        trialMonitor.OnTrialCompleted += (sender, trial) =>
        {
            var completed = BestResultUtil.ToRunDetail(Context, trial, _pipeline);
            progressHandler?.Report(completed);
        };

        return trialMonitor;
    });

    _experiment.SetTrialRunner<RegressionTrialRunner>();
    _experiment.Run();

    // Convert the monitor's trial results into the public run-detail types.
    var allRuns = trialMonitor.RunDetails.Select(trial => BestResultUtil.ToRunDetail(Context, trial, _pipeline));
    var best = BestResultUtil.ToRunDetail(Context, trialMonitor.BestRun, _pipeline);

    return new ExperimentResult<RegressionMetrics>(allRuns, best);
}
|
||
/// <summary>
/// Convenience overload: wraps <paramref name="labelColumnName"/> in a <c>ColumnInformation</c>
/// and forwards to the train / validation overload that takes column information.
/// </summary>
/// <param name="trainData">Training data.</param>
/// <param name="validationData">Validation data.</param>
/// <param name="labelColumnName">Name of the label column.</param>
/// <param name="preFeaturizer">Optional estimator forwarded unchanged.</param>
/// <param name="progressHandler">Optional progress callback forwarded unchanged.</param>
/// <returns>The result of the forwarded call.</returns>
public override ExperimentResult<RegressionMetrics> Execute(IDataView trainData, IDataView validationData, string labelColumnName = "Label", IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<RegressionMetrics>> progressHandler = null)
{
    var columns = new ColumnInformation();
    columns.LabelColumnName = labelColumnName;

    return Execute(trainData, validationData, columns, preFeaturizer, progressHandler);
}
|
||
/// <summary>
/// Convenience overload: wraps the label and sampling-key column names in a <c>ColumnInformation</c>
/// and forwards to the single-dataset overload that takes column information.
/// </summary>
/// <param name="trainData">Training data.</param>
/// <param name="labelColumnName">Name of the label column.</param>
/// <param name="samplingKeyColumn">Name of the sampling key column; may be <c>null</c>.</param>
/// <param name="preFeaturizer">Optional estimator forwarded unchanged.</param>
/// <param name="progressHandler">Optional progress callback forwarded unchanged.</param>
/// <returns>The result of the forwarded call.</returns>
public override ExperimentResult<RegressionMetrics> Execute(IDataView trainData, string labelColumnName = "Label", string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<RegressionMetrics>> progressHandler = null)
{
    var columns = new ColumnInformation();
    columns.LabelColumnName = labelColumnName;
    columns.SamplingKeyColumnName = samplingKeyColumn;

    return Execute(trainData, columns, preFeaturizer, progressHandler);
}
|
||
/// <summary>
/// Runs the regression experiment using cross validation with a caller-supplied fold count.
/// </summary>
/// <param name="trainData">Data that is cross validated.</param>
/// <param name="numberOfCVFolds">Number of folds handed to the experiment.</param>
/// <param name="columnInformation">
/// Column roles; its <c>LabelColumnName</c> is used as the regression label.
/// When omitted (<c>null</c>) a default-constructed <c>ColumnInformation</c> is used.
/// </param>
/// <param name="preFeaturizer">Optional estimator that is placed in front of the generated pipeline.</param>
/// <param name="progressHandler">Optional callback that receives a cross-validation run detail after every completed trial.</param>
/// <returns>All cross-validation run details collected by the monitor, together with the best one.</returns>
public override CrossValidationExperimentResult<RegressionMetrics> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<CrossValidationRunDetail<RegressionMetrics>> progressHandler = null)
{
    // The parameter is declared optional, so the default (null) must not crash on the first dereference.
    // NOTE(review): a default ColumnInformation is expected to use "Label" as label column — confirm.
    columnInformation ??= new ColumnInformation();

    var label = columnInformation.LabelColumnName;
    _experiment.SetRegressionMetric(Settings.OptimizingMetric, label);
    _experiment.SetDataset(trainData, (int)numberOfCVFolds);

    _pipeline = CreateRegressionPipeline(trainData, columnInformation, preFeaturizer);
    _experiment.SetPipeline(_pipeline);

    // The monitor is created lazily by the experiment; capture it so results can be read after Run().
    TrialResultMonitor<RegressionMetrics> monitor = null;
    _experiment.SetMonitor((provider) =>
    {
        var channel = provider.GetService<IChannel>();
        var pipeline = provider.GetService<SweepablePipeline>();
        monitor = new TrialResultMonitor<RegressionMetrics>(channel, pipeline);
        monitor.OnTrialCompleted += (o, e) =>
        {
            var detail = BestResultUtil.ToCrossValidationRunDetail(Context, e, _pipeline);
            progressHandler?.Report(detail);
        };

        return monitor;
    });

    _experiment.SetTrialRunner<RegressionTrialRunner>();
    _experiment.Run();

    var runDetails = monitor.RunDetails.Select(e => BestResultUtil.ToCrossValidationRunDetail(Context, e, _pipeline));
    var bestResult = BestResultUtil.ToCrossValidationRunDetail(Context, monitor.BestRun, _pipeline);

    var result = new CrossValidationExperimentResult<RegressionMetrics>(runDetails, bestResult);

    return result;
}
|
||
/// <summary>
/// Convenience overload: wraps the label and sampling-key column names in a <c>ColumnInformation</c>
/// and forwards to the cross-validation overload that takes column information.
/// </summary>
/// <param name="trainData">Data that is cross validated.</param>
/// <param name="numberOfCVFolds">Number of cross-validation folds.</param>
/// <param name="labelColumnName">Name of the label column.</param>
/// <param name="samplingKeyColumn">Name of the sampling key column; may be <c>null</c>.</param>
/// <param name="preFeaturizer">Optional estimator forwarded unchanged.</param>
/// <param name="progressHandler">Optional progress callback forwarded unchanged.</param>
/// <returns>The result of the forwarded call.</returns>
public override CrossValidationExperimentResult<RegressionMetrics> Execute(IDataView trainData, uint numberOfCVFolds, string labelColumnName = "Label", string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<CrossValidationRunDetail<RegressionMetrics>> progressHandler = null)
{
    var columns = new ColumnInformation();
    columns.LabelColumnName = labelColumnName;
    columns.SamplingKeyColumnName = samplingKeyColumn;

    return Execute(trainData, numberOfCVFolds, columns, preFeaturizer, progressHandler);
}
|
||
/// <summary>
/// Assembles the sweepable pipeline for this experiment: the optional pre-featurizer,
/// then the auto featurizer writing to the internal features column, then the regression
/// trainers enabled in <c>Settings.Trainers</c>.
/// </summary>
/// <param name="trainData">Data passed to the auto featurizer.</param>
/// <param name="columnInformation">Column roles; supplies the label column name.</param>
/// <param name="preFeaturizer">Optional estimator appended first when not <c>null</c>.</param>
/// <returns>The assembled pipeline.</returns>
private SweepablePipeline CreateRegressionPipeline(IDataView trainData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null)
{
    // Translate the requested trainer list into the individual switches of the regression sweeper.
    bool sdcaEnabled = Settings.Trainers.Contains(RegressionTrainer.StochasticDualCoordinateAscent);
    bool lbfgsEnabled = Settings.Trainers.Contains(RegressionTrainer.LbfgsPoissonRegression);
    bool lgbmEnabled = Settings.Trainers.Contains(RegressionTrainer.LightGbm);
    bool fastForestEnabled = Settings.Trainers.Contains(RegressionTrainer.FastForest);
    // FastTree and FastTreeTweedie both map onto the single fast-tree switch.
    bool fastTreeEnabled = Settings.Trainers.Contains(RegressionTrainer.FastTree)
        || Settings.Trainers.Contains(RegressionTrainer.FastTreeTweedie);

    var result = new SweepablePipeline();
    if (preFeaturizer != null)
    {
        result = result.Append(preFeaturizer);
    }

    var labelColumn = columnInformation.LabelColumnName;

    var featurizer = Context.Auto().Featurizer(trainData, columnInformation, Features);
    result = result.Append(featurizer);

    var regressors = Context.Auto().Regression(
        labelColumn,
        useSdca: sdcaEnabled,
        useFastTree: fastTreeEnabled,
        useLgbm: lgbmEnabled,
        useLbfgs: lbfgsEnabled,
        useFastForest: fastForestEnabled,
        featureColumnName: Features);
    result = result.Append(regressors);

    return result;
}
|
||
private protected override CrossValidationRunDetail<RegressionMetrics> GetBestCrossValRun(IEnumerable<CrossValidationRunDetail<RegressionMetrics>> results) | ||
|
@@ -176,4 +356,124 @@ public static CrossValidationRunDetail<RegressionMetrics> Best(this IEnumerable< | |
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing); | ||
} | ||
} | ||
|
||
/// <summary>
/// Trial runner for regression experiments. For every trial it builds a concrete pipeline
/// from the sweepable pipeline and the trial's parameters, then either cross validates it or
/// fits it on the train set and evaluates it on the test set, depending on the dataset manager.
/// </summary>
internal class RegressionTrialRunner : ITrialRunner
{
    // Not readonly: cleared in Dispose so a late cancellation callback becomes a no-op.
    private MLContext _context;
    private readonly IDatasetManager _datasetManager;
    private readonly IMetricManager _metricManager;
    private readonly SweepablePipeline _pipeline;
    private readonly Random _rnd;

    /// <summary>
    /// Creates the runner. The random generator is seeded from <c>settings.Seed</c> when one is set.
    /// </summary>
    public RegressionTrialRunner(MLContext context, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
    {
        _context = context;
        _datasetManager = datasetManager;
        _metricManager = metricManager;
        _pipeline = pipeline;
        _rnd = settings.Seed.HasValue ? new Random(settings.Seed.Value) : new Random();
    }

    /// <summary>
    /// Runs a single trial synchronously and returns its result wrapped in a completed task.
    /// </summary>
    /// <param name="settings">Trial settings; the pipeline parameters are read from its <c>Parameter</c>.</param>
    /// <param name="ct">Token whose cancellation cancels execution on the ML context.</param>
    /// <returns>A completed task holding the trial result.</returns>
    /// <exception cref="ArgumentException">The metric manager or dataset manager is of an unsupported type.</exception>
    /// <exception cref="NotImplementedException">The configured regression metric is not handled.</exception>
    /// <exception cref="OperationCanceledException">Any failure that occurs after cancellation was requested.</exception>
    public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
    {
        try
        {
            // Tie the token to the ML context for the duration of the trial.
            using (ct.Register(() => _context?.CancelExecution()))
            {
                if (!(_metricManager is RegressionMetricManager metricManager))
                {
                    throw new ArgumentException($"The runner metric manager is of type {_metricManager?.GetType()} which expected to be of type {typeof(RegressionMetricManager)}");
                }

                var parameter = settings.Parameter[AutoMLExperiment.PipelineSearchspaceName];
                var pipeline = _pipeline.BuildFromOption(_context, parameter);

                if (_datasetManager is ICrossValidateDatasetManager datasetManager)
                {
                    var stopWatch = Stopwatch.StartNew();
                    var fold = datasetManager.Fold ?? 5;
                    var metrics = _context.Regression.CrossValidate(datasetManager.Dataset, pipeline, fold, metricManager.LabelColumn);

                    // now we just randomly pick a model, but a better way is to provide option to pick a model which score is the closest to average or the best.
                    var res = metrics[_rnd.Next(fold)];
                    var model = res.Model;
                    var metric = GetMetricValue(metricManager, res.Metrics);
                    // The tuner minimizes loss, so metrics that should be maximized are negated.
                    var loss = metricManager.IsMaximize ? -metric : metric;

                    stopWatch.Stop();

                    return Task.FromResult<TrialResult>(new TrialResult<RegressionMetrics>()
                    {
                        Loss = loss,
                        Metric = metric,
                        Model = model,
                        TrialSettings = settings,
                        DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
                        Metrics = res.Metrics,
                        CrossValidationMetrics = metrics,
                        Pipeline = pipeline,
                    });
                }

                if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
                {
                    var stopWatch = Stopwatch.StartNew();
                    var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
                    var eval = model.Transform(trainTestDatasetManager.TestDataset);
                    var res = _context.Regression.Evaluate(eval, metricManager.LabelColumn, scoreColumnName: metricManager.ScoreColumn);

                    var metric = GetMetricValue(metricManager, res);
                    // The tuner minimizes loss, so metrics that should be maximized are negated.
                    var loss = metricManager.IsMaximize ? -metric : metric;

                    stopWatch.Stop();

                    return Task.FromResult<TrialResult>(new TrialResult<RegressionMetrics>()
                    {
                        Loss = loss,
                        Metric = metric,
                        Model = model,
                        TrialSettings = settings,
                        DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
                        Metrics = res,
                        Pipeline = pipeline,
                    });
                }

                throw new ArgumentException($"The runner dataset manager is of type {_datasetManager?.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
            }
        }
        catch (Exception ex) when (ct.IsCancellationRequested)
        {
            // Anything that fails once cancellation was requested is surfaced as a cancellation.
            throw new OperationCanceledException(ex.Message, ex.InnerException);
        }
    }

    /// <summary>
    /// Cancels any execution on the ML context and drops the reference. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        _context?.CancelExecution();
        _context = null;
    }

    // Picks the value of the metric being optimized out of a set of regression metrics.
    private static double GetMetricValue(RegressionMetricManager metricManager, RegressionMetrics metrics)
    {
        return metricManager.Metric switch
        {
            RegressionMetric.RootMeanSquaredError => metrics.RootMeanSquaredError,
            RegressionMetric.RSquared => metrics.RSquared,
            RegressionMetric.MeanSquaredError => metrics.MeanSquaredError,
            RegressionMetric.MeanAbsoluteError => metrics.MeanAbsoluteError,
            _ => throw new NotImplementedException($"{metricManager.Metric} is not supported!"),
        };
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: you may add the parameter name before `10` for code readability :-)