Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Test to Validate in Dataset manager #6599

Merged
1 commit merged on Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ public static class AutoMLExperimentExtension
/// <returns><see cref="AutoMLExperiment"/></returns>
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView train, IDataView validation)
{
var datasetManager = new TrainTestDatasetManager()
var datasetManager = new TrainValidateDatasetManager()
{
TrainDataset = train,
TestDataset = validation
ValidateDataset = validation
};

experiment.ServiceCollection.AddSingleton<IDatasetManager>(datasetManager);
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -400,12 +400,12 @@ public TrialResult Run(TrialSettings settings)
};
}

if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
var eval = model.Transform(trainTestDatasetManager.TestDataset);
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var metrics = _context.BinaryClassification.EvaluateNonCalibrated(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
var metric = GetMetric(metricManager.Metric, metrics);
var loss = metricManager.IsMaximize ? -metric : metric;
Expand All @@ -426,7 +426,7 @@ public TrialResult Run(TrialSettings settings)
}
}

throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
}

public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,12 @@ public TrialResult Run(TrialSettings settings)
};
}

if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
var eval = model.Transform(trainTestDatasetManager.TestDataset);
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var metrics = _context.MulticlassClassification.Evaluate(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
var metric = GetMetric(metricManager.Metric, metrics);
var loss = metricManager.IsMaximize ? -metric : metric;
Expand All @@ -420,7 +420,7 @@ public TrialResult Run(TrialSettings settings)
}
}

throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
}

public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.AutoML/API/RegressionExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,12 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
} as TrialResult);
}

if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
var eval = model.Transform(trainTestDatasetManager.TestDataset);
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var metrics = _context.Regression.Evaluate(eval, metricManager.LabelColumn, scoreColumnName: metricManager.ScoreColumn);
var metric = GetMetric(metricManager.Metric, metrics);
var loss = metricManager.IsMaximize ? -metric : metric;
Expand All @@ -447,7 +447,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
}
}

throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
}
}
catch (Exception ex) when (ct.IsCancellationRequested)
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ internal interface ICrossValidateDatasetManager
IDataView Dataset { get; set; }
}

internal interface ITrainTestDatasetManager
internal interface ITrainValidateDatasetManager
{
IDataView TrainDataset { get; set; }

IDataView TestDataset { get; set; }
IDataView ValidateDataset { get; set; }
}

internal class TrainTestDatasetManager : IDatasetManager, ITrainTestDatasetManager
internal class TrainValidateDatasetManager : IDatasetManager, ITrainValidateDatasetManager
{
public IDataView TrainDataset { get; set; }

public IDataView TestDataset { get; set; }
public IDataView ValidateDataset { get; set; }
}

internal class CrossValidateDatasetManager : IDatasetManager, ICrossValidateDatasetManager
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ public TrialResult Run(TrialSettings settings)
};
}

if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var model = mlnetPipeline.Fit(trainTestDatasetManager.TrainDataset);
var eval = model.Transform(trainTestDatasetManager.TestDataset);
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var metric = _metricManager.Evaluate(_mLContext, eval);
stopWatch.Stop();
var loss = _metricManager.IsMaximize ? -metric : metric;
Expand Down