Skip to content

Commit

Permalink
Change test to validate (#6599)
Browse files Browse the repository at this point in the history
  • Loading branch information
zewditu authored Apr 21, 2023
1 parent c696e09 commit ebb5789
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ public static class AutoMLExperimentExtension
/// <returns><see cref="AutoMLExperiment"/></returns>
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView train, IDataView validation)
{
var datasetManager = new TrainTestDatasetManager()
var datasetManager = new TrainValidateDatasetManager()
{
TrainDataset = train,
TestDataset = validation
ValidateDataset = validation
};

experiment.ServiceCollection.AddSingleton<IDatasetManager>(datasetManager);
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -400,12 +400,12 @@ public TrialResult Run(TrialSettings settings)
};
}

if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
var eval = model.Transform(trainTestDatasetManager.TestDataset);
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var metrics = _context.BinaryClassification.EvaluateNonCalibrated(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
var metric = GetMetric(metricManager.Metric, metrics);
var loss = metricManager.IsMaximize ? -metric : metric;
Expand All @@ -426,7 +426,7 @@ public TrialResult Run(TrialSettings settings)
}
}

throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
}

public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,12 @@ public TrialResult Run(TrialSettings settings)
};
}

if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
var eval = model.Transform(trainTestDatasetManager.TestDataset);
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var metrics = _context.MulticlassClassification.Evaluate(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
var metric = GetMetric(metricManager.Metric, metrics);
var loss = metricManager.IsMaximize ? -metric : metric;
Expand All @@ -420,7 +420,7 @@ public TrialResult Run(TrialSettings settings)
}
}

throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
}

public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.AutoML/API/RegressionExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,12 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
} as TrialResult);
}

if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
var eval = model.Transform(trainTestDatasetManager.TestDataset);
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var metrics = _context.Regression.Evaluate(eval, metricManager.LabelColumn, scoreColumnName: metricManager.ScoreColumn);
var metric = GetMetric(metricManager.Metric, metrics);
var loss = metricManager.IsMaximize ? -metric : metric;
Expand All @@ -447,7 +447,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
}
}

throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
}
}
catch (Exception ex) when (ct.IsCancellationRequested)
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ internal interface ICrossValidateDatasetManager
IDataView Dataset { get; set; }
}

internal interface ITrainTestDatasetManager
internal interface ITrainValidateDatasetManager
{
IDataView TrainDataset { get; set; }

IDataView TestDataset { get; set; }
IDataView ValidateDataset { get; set; }
}

internal class TrainTestDatasetManager : IDatasetManager, ITrainTestDatasetManager
internal class TrainValidateDatasetManager : IDatasetManager, ITrainValidateDatasetManager
{
public IDataView TrainDataset { get; set; }

public IDataView TestDataset { get; set; }
public IDataView ValidateDataset { get; set; }
}

internal class CrossValidateDatasetManager : IDatasetManager, ICrossValidateDatasetManager
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ public TrialResult Run(TrialSettings settings)
};
}

if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var model = mlnetPipeline.Fit(trainTestDatasetManager.TrainDataset);
var eval = model.Transform(trainTestDatasetManager.TestDataset);
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var metric = _metricManager.Evaluate(_mLContext, eval);
stopWatch.Stop();
var loss = _metricManager.IsMaximize ? -metric : metric;
Expand Down

0 comments on commit ebb5789

Please sign in to comment.