Skip to content

Commit

Permalink
make AutoMLExperiment public && some small refactor (#6173)
Browse files Browse the repository at this point in the history
  • Loading branch information
LittleLittleCloud authored Apr 27, 2022
1 parent 0e7e807 commit 9336dae
Show file tree
Hide file tree
Showing 17 changed files with 234 additions and 77 deletions.
80 changes: 75 additions & 5 deletions src/Microsoft.ML.AutoML/API/AutoCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.ML.AutoML.CodeGen;
using Microsoft.ML.Data;
using Microsoft.ML.SearchSpace;
using Microsoft.ML.Trainers.FastTree;

namespace Microsoft.ML.AutoML
{
Expand Down Expand Up @@ -286,18 +287,43 @@ public ColumnInferenceResults InferColumns(string path, uint labelColumnIndex, b
/// <summary>
/// Create a sweepable estimator with a custom factory and search space.
/// </summary>
internal SweepableEstimator CreateSweepableEstimator<T>(Func<MLContext, T, IEstimator<ITransformer>> factory, SearchSpace<T> ss = null)
public SweepableEstimator CreateSweepableEstimator<T>(Func<MLContext, T, IEstimator<ITransformer>> factory, SearchSpace<T> ss = null)
where T : class, new()
{
return new SweepableEstimator((MLContext context, Parameter param) => factory(context, param.AsType<T>()), ss);
}

internal AutoMLExperiment CreateExperiment()
/// <summary>
/// Create an <see cref="AutoMLExperiment"/>.
/// </summary>
public AutoMLExperiment CreateExperiment()
{
return new AutoMLExperiment(_context, new AutoMLExperiment.AutoMLExperimentSettings());
}

internal SweepableEstimator[] BinaryClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
/// <summary>
/// Create a list of <see cref="SweepableEstimator"/> for binary classification.
/// </summary>
/// <param name="labelColumnName">label column name.</param>
/// <param name="featureColumnName">feature column name.</param>
/// <param name="exampleWeightColumnName">example weight column name.</param>
/// <param name="useFastForest">true if use fast forest as available trainer.</param>
/// <param name="useLgbm">true if use lgbm as available trainer.</param>
/// <param name="useFastTree">true if use fast tree as available trainer.</param>
/// <param name="useLbfgs">true if use lbfgs as available trainer.</param>
/// <param name="useSdca">true if use sdca as available trainer.</param>
/// <param name="fastTreeOption">if provided, use it as initial option for fast tree, otherwise the default option will be used.</param>
/// <param name="lgbmOption">if provided, use it as initial option for lgbm, otherwise the default option will be used.</param>
/// <param name="fastForestOption">if provided, use it as initial option for fast forest, otherwise the default option will be used.</param>
/// <param name="lbfgsOption">if provided, use it as initial option for lbfgs, otherwise the default option will be used.</param>
/// <param name="sdcaOption">if provided, use it as initial option for sdca, otherwise the default option will be used.</param>
/// <param name="fastTreeSearchSpace">if provided, use it as search space for fast tree, otherwise the default search space will be used.</param>
/// <param name="lgbmSearchSpace">if provided, use it as search space for lgbm, otherwise the default search space will be used.</param>
/// <param name="fastForestSearchSpace">if provided, use it as search space for fast forest, otherwise the default search space will be used.</param>
/// <param name="lbfgsSearchSpace">if provided, use it as search space for lbfgs, otherwise the default search space will be used.</param>
/// <param name="sdcaSearchSpace">if provided, use it as search space for sdca, otherwise the default search space will be used.</param>
/// <returns></returns>
public SweepableEstimator[] BinaryClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
{
Expand Down Expand Up @@ -351,7 +377,29 @@ internal SweepableEstimator[] BinaryClassification(string labelColumnName = Defa
return res.ToArray();
}

internal SweepableEstimator[] MultiClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
/// <summary>
/// Create a list of <see cref="SweepableEstimator"/> for multiclass classification.
/// </summary>
/// <param name="labelColumnName">label column name.</param>
/// <param name="featureColumnName">feature column name.</param>
/// <param name="exampleWeightColumnName">example weight column name.</param>
/// <param name="useFastForest">true if use fast forest as available trainer.</param>
/// <param name="useLgbm">true if use lgbm as available trainer.</param>
/// <param name="useFastTree">true if use fast tree as available trainer.</param>
/// <param name="useLbfgs">true if use lbfgs as available trainer.</param>
/// <param name="useSdca">true if use sdca as available trainer.</param>
/// <param name="fastTreeOption">if provided, use it as initial option for fast tree, otherwise the default option will be used.</param>
/// <param name="lgbmOption">if provided, use it as initial option for lgbm, otherwise the default option will be used.</param>
/// <param name="fastForestOption">if provided, use it as initial option for fast forest, otherwise the default option will be used.</param>
/// <param name="lbfgsOption">if provided, use it as initial option for lbfgs, otherwise the default option will be used.</param>
/// <param name="sdcaOption">if provided, use it as initial option for sdca, otherwise the default option will be used.</param>
/// <param name="fastTreeSearchSpace">if provided, use it as search space for fast tree, otherwise the default search space will be used.</param>
/// <param name="lgbmSearchSpace">if provided, use it as search space for lgbm, otherwise the default search space will be used.</param>
/// <param name="fastForestSearchSpace">if provided, use it as search space for fast forest, otherwise the default search space will be used.</param>
/// <param name="lbfgsSearchSpace">if provided, use it as search space for lbfgs, otherwise the default search space will be used.</param>
/// <param name="sdcaSearchSpace">if provided, use it as search space for sdca, otherwise the default search space will be used.</param>
/// <returns></returns>
public SweepableEstimator[] MultiClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
{
Expand Down Expand Up @@ -407,7 +455,29 @@ internal SweepableEstimator[] MultiClassification(string labelColumnName = Defau
return res.ToArray();
}

internal SweepableEstimator[] Regression(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
/// <summary>
/// Create a list of <see cref="SweepableEstimator"/> for regression.
/// </summary>
/// <param name="labelColumnName">label column name.</param>
/// <param name="featureColumnName">feature column name.</param>
/// <param name="exampleWeightColumnName">example weight column name.</param>
/// <param name="useFastForest">true if use fast forest as available trainer.</param>
/// <param name="useLgbm">true if use lgbm as available trainer.</param>
/// <param name="useFastTree">true if use fast tree as available trainer.</param>
/// <param name="useLbfgs">true if use lbfgs as available trainer.</param>
/// <param name="useSdca">true if use sdca as available trainer.</param>
/// <param name="fastTreeOption">if provided, use it as initial option for fast tree, otherwise the default option will be used.</param>
/// <param name="lgbmOption">if provided, use it as initial option for lgbm, otherwise the default option will be used.</param>
/// <param name="fastForestOption">if provided, use it as initial option for fast forest, otherwise the default option will be used.</param>
/// <param name="lbfgsOption">if provided, use it as initial option for lbfgs, otherwise the default option will be used.</param>
/// <param name="sdcaOption">if provided, use it as initial option for sdca, otherwise the default option will be used.</param>
/// <param name="fastTreeSearchSpace">if provided, use it as search space for fast tree, otherwise the default search space will be used.</param>
/// <param name="lgbmSearchSpace">if provided, use it as search space for lgbm, otherwise the default search space will be used.</param>
/// <param name="fastForestSearchSpace">if provided, use it as search space for fast forest, otherwise the default search space will be used.</param>
/// <param name="lbfgsSearchSpace">if provided, use it as search space for lbfgs, otherwise the default search space will be used.</param>
/// <param name="sdcaSearchSpace">if provided, use it as search space for sdca, otherwise the default search space will be used.</param>
/// <returns></returns>
public SweepableEstimator[] Regression(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.AutoML/API/SweepableExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace Microsoft.ML.AutoML
{
internal static class SweepableExtension
public static class SweepableExtension
{
public static SweepableEstimatorPipeline Append(this IEstimator<ITransformer> estimator, SweepableEstimator estimator1)
{
Expand Down
45 changes: 29 additions & 16 deletions src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

namespace Microsoft.ML.AutoML
{
internal class AutoMLExperiment
public class AutoMLExperiment
{
private readonly AutoMLExperimentSettings _settings;
private readonly MLContext _context;
Expand Down Expand Up @@ -52,12 +52,14 @@ public AutoMLExperiment SetTrainingTimeInSeconds(uint trainingTimeInSeconds)

public AutoMLExperiment SetDataset(IDataView train, IDataView test)
{
_settings.DatasetSettings = new TrainTestDatasetSettings()
var datasetManager = new TrainTestDatasetManager()
{
TrainDataset = train,
TestDataset = test
};

_serviceCollection.AddSingleton<IDatasetManager>(datasetManager);

return this;
}

Expand All @@ -70,12 +72,14 @@ public AutoMLExperiment SetDataset(TrainTestData trainTestSplit)

public AutoMLExperiment SetDataset(IDataView dataset, int fold = 10)
{
_settings.DatasetSettings = new CrossValidateDatasetSettings()
var datasetManager = new CrossValidateDatasetManager()
{
Dataset = dataset,
Fold = fold,
};

_serviceCollection.AddSingleton<IDatasetManager>(datasetManager);

return this;
}

Expand Down Expand Up @@ -116,8 +120,15 @@ public AutoMLExperiment SetPipeline(MultiModelPipeline pipeline)
return this;
}

public AutoMLExperiment SetTrialRunnerFactory(ITrialRunnerFactory factory)
public AutoMLExperiment SetIsMaximizeMetric(bool isMaximize)
{
_settings.IsMaximizeMetric = isMaximize;
return this;
}

public AutoMLExperiment SetTrialRunner(ITrialRunner runner)
{
var factory = new CustomRunnerFactory(runner);
var descriptor = new ServiceDescriptor(typeof(ITrialRunnerFactory), factory);
if (_serviceCollection.Contains(descriptor))
{
Expand Down Expand Up @@ -146,36 +157,42 @@ public AutoMLExperiment SetPipeline(SweepableEstimatorPipeline pipeline)

public AutoMLExperiment SetEvaluateMetric(BinaryClassificationMetric metric, string labelColumn = "label", string predictedColumn = "Predicted")
{
_settings.EvaluateMetric = new BinaryMetricSettings()
var metricManager = new BinaryMetricManager()
{
Metric = metric,
PredictedColumn = predictedColumn,
LabelColumn = labelColumn,
};
_serviceCollection.AddSingleton<IMetricManager>(metricManager);
SetIsMaximizeMetric(metricManager.IsMaximize);

return this;
}

public AutoMLExperiment SetEvaluateMetric(MulticlassClassificationMetric metric, string labelColumn = "label", string predictedColumn = "Predicted")
{
_settings.EvaluateMetric = new MultiClassMetricSettings()
var metricManager = new MultiClassMetricManager()
{
Metric = metric,
PredictedColumn = predictedColumn,
LabelColumn = labelColumn,
};
_serviceCollection.AddSingleton<IMetricManager>(metricManager);
SetIsMaximizeMetric(metricManager.IsMaximize);

return this;
}

public AutoMLExperiment SetEvaluateMetric(RegressionMetric metric, string labelColumn = "label", string scoreColumn = "Score")
{
_settings.EvaluateMetric = new RegressionMetricSettings()
var metricManager = new RegressionMetricManager()
{
Metric = metric,
ScoreColumn = scoreColumn,
LabelColumn = labelColumn,
};
_serviceCollection.AddSingleton<IMetricManager>(metricManager);
SetIsMaximizeMetric(metricManager.IsMaximize);

return this;
}
Expand Down Expand Up @@ -224,13 +241,13 @@ private async Task<TrialResult> RunAsync(CancellationToken ct)
setting = pipelineProposer.Propose(setting);
setting = hyperParameterProposer.Propose(setting);
monitor.ReportRunningTrial(setting);
var runner = runnerFactory.CreateTrialRunner(setting);
var trialResult = runner.Run(setting);
var runner = runnerFactory.CreateTrialRunner();
var trialResult = runner.Run(setting, serviceProvider);
monitor.ReportCompletedTrial(trialResult);
hyperParameterProposer.Update(setting, trialResult);
pipelineProposer.Update(setting, trialResult);

var error = _settings.EvaluateMetric.IsMaximize ? 1 - trialResult.Metric : trialResult.Metric;
var error = _settings.IsMaximizeMetric ? 1 - trialResult.Metric : trialResult.Metric;
if (error < _bestError)
{
_bestTrialResult = trialResult;
Expand Down Expand Up @@ -264,20 +281,16 @@ private async Task<TrialResult> RunAsync(CancellationToken ct)
private void ValidateSettings()
{
Contracts.Assert(_settings.MaxExperimentTimeInSeconds > 0, $"{nameof(ExperimentSettings.MaxExperimentTimeInSeconds)} must be larger than 0");
Contracts.Assert(_settings.DatasetSettings != null, $"{nameof(_settings.DatasetSettings)} must be not null");
Contracts.Assert(_settings.EvaluateMetric != null, $"{nameof(_settings.EvaluateMetric)} must be not null");
}


public class AutoMLExperimentSettings : ExperimentSettings
{
public IDatasetSettings DatasetSettings { get; set; }

public IMetricSettings EvaluateMetric { get; set; }

public MultiModelPipeline Pipeline { get; set; }

public int? Seed { get; set; }

public bool IsMaximizeMetric { get; set; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,22 @@

namespace Microsoft.ML.AutoML
{
internal interface IDatasetSettings
/// <summary>
/// Interface for dataset manager. This interface doesn't include any method or property definition and is used by <see cref="AutoMLExperiment"/> and other components to retrieve the instance of the actual
/// dataset manager from containers.
/// </summary>
public interface IDatasetManager
{
}

internal class TrainTestDatasetSettings : IDatasetSettings
public class TrainTestDatasetManager : IDatasetManager
{
public IDataView TrainDataset { get; set; }

public IDataView TestDataset { get; set; }
}

internal class CrossValidateDatasetSettings : IDatasetSettings
public class CrossValidateDatasetManager : IDatasetManager
{
public IDataView Dataset { get; set; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@

namespace Microsoft.ML.AutoML
{
internal interface IMetricSettings
/// <summary>
/// Interface for metric manager.
/// </summary>
internal interface IMetricManager
{
bool IsMaximize { get; }
}

internal class BinaryMetricSettings : IMetricSettings
internal class BinaryMetricManager : IMetricManager
{
public BinaryClassificationMetric Metric { get; set; }

Expand All @@ -33,7 +36,7 @@ internal class BinaryMetricSettings : IMetricSettings
};
}

internal class MultiClassMetricSettings : IMetricSettings
internal class MultiClassMetricManager : IMetricManager
{
public MulticlassClassificationMetric Metric { get; set; }

Expand All @@ -52,7 +55,7 @@ internal class MultiClassMetricSettings : IMetricSettings
};
}

internal class RegressionMetricSettings : IMetricSettings
internal class RegressionMetricManager : IMetricManager
{
public RegressionMetric Metric { get; set; }

Expand Down
5 changes: 4 additions & 1 deletion src/Microsoft.ML.AutoML/AutoMLExperiment/IMonitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@

namespace Microsoft.ML.AutoML
{
internal interface IMonitor
/// <summary>
/// instance for monitor, which is used by <see cref="AutoMLExperiment"/> to report training progress.
/// </summary>
public interface IMonitor
{
void ReportCompletedTrial(TrialResult result);

Expand Down
Loading

0 comments on commit 9336dae

Please sign in to comment.