Skip to content

Commit 6a8a242

Browse files
committed
CV test
1 parent 39f1e6a commit 6a8a242

File tree

4 files changed

+46
-2
lines changed

4 files changed

+46
-2
lines changed

src/Microsoft.ML.AutoML/API/ColumnInference.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ public sealed class ColumnInformation
118118
public ColumnInformation()
119119
{
120120
LabelColumnName = DefaultColumnNames.Label;
121+
GroupIdColumnName = DefaultColumnNames.GroupId;
121122
CategoricalColumnNames = new Collection<string>();
122123
NumericColumnNames = new Collection<string>();
123124
TextColumnNames = new Collection<string>();

src/Microsoft.ML.AutoML/API/ExperimentBase.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,8 @@ public CrossValidationExperimentResult<TMetrics> Execute(IDataView trainData, ui
194194
IProgress<CrossValidationRunDetail<TMetrics>> progressHandler = null)
195195
{
196196
UserInputValidationUtil.ValidateNumberOfCVFoldsArg(numberOfCVFolds);
197-
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, columnInformation?.SamplingKeyColumnName);
197+
UserInputValidationUtil.ValidateSamplingKey(columnInformation?.SamplingKeyColumnName, columnInformation?.GroupIdColumnName, _task);
198+
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, columnInformation?.GroupIdColumnName);
198199
return ExecuteCrossVal(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler);
199200
}
200201

src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ public static void ValidateNumberOfCVFoldsArg(uint numberOfCVFolds)
5757
}
5858
}
5959

60+
/// <summary>
/// Verifies that, for ranking tasks, an explicitly supplied sampling key column
/// is the same column as the group id column.
/// </summary>
/// <param name="samplingKeyColumnName">
/// Sampling key column name supplied by the caller. A <c>null</c> value skips the check.
/// </param>
/// <param name="groupIdColumnName">Group id column name the sampling key is compared against.</param>
/// <param name="task">Kind of task being run; only <c>TaskKind.Ranking</c> is validated.</param>
/// <exception cref="ArgumentException">
/// Thrown when <paramref name="task"/> is <c>TaskKind.Ranking</c> and
/// <paramref name="samplingKeyColumnName"/> is non-null and differs from
/// <paramref name="groupIdColumnName"/>.
/// </exception>
public static void ValidateSamplingKey(string samplingKeyColumnName, string groupIdColumnName, TaskKind task)
{
    // Nothing to validate for non-ranking tasks or when no sampling key was given.
    if (task != TaskKind.Ranking || samplingKeyColumnName == null)
    {
        return;
    }

    if (samplingKeyColumnName != groupIdColumnName)
    {
        // The second constructor argument is the *name* of the offending parameter,
        // not its runtime value, so ParamName and the generated message stay meaningful.
        throw new ArgumentException(
            $"{nameof(samplingKeyColumnName)} must be the same as {nameof(groupIdColumnName)}",
            nameof(samplingKeyColumnName));
    }
}
67+
6068
private static void ValidateTrainData(IDataView trainData, ColumnInformation columnInformation)
6169
{
6270
if (trainData == null)

test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
using Microsoft.ML.TestFramework;
88
using Microsoft.ML.TestFramework.Attributes;
99
using Microsoft.ML.TestFrameworkCommon;
10-
using Xunit;
1110
using Microsoft.ML.Trainers.LightGbm;
11+
using Xunit;
1212
using Xunit.Abstractions;
1313
using static Microsoft.ML.DataOperationsCatalog;
1414

@@ -182,6 +182,40 @@ public void AutoFitRankingCVTest()
182182
}
183183
}
184184

185+
/// <summary>
/// Runs a cross-validated ranking experiment where the group id column is supplied
/// through <c>ColumnInformation.GroupIdColumnName</c>, and checks that every fold of
/// the best run clears the NDCG and DCG thresholds asserted below.
/// </summary>
[LightGBMFact]
public void AutoFitRankingCV2Test()
{
    string labelColumnName = "Label";
    string groupIdColumnName = "GroupId";
    string featuresColumnVectorNameA = "FeatureVectorA";
    string featuresColumnVectorNameB = "FeatureVectorB";
    uint numFolds = 3;

    var mlContext = new MLContext(1);
    var reader = new TextLoader(mlContext, GetLoaderArgsRank(labelColumnName, groupIdColumnName,
        featuresColumnVectorNameA, featuresColumnVectorNameB));
    var trainDataView = reader.Load(new MultiFileSource(DatasetUtil.GetMLSRDataset()));

    var columnInformation = new ColumnInformation()
    {
        LabelColumnName = labelColumnName,
        GroupIdColumnName = groupIdColumnName
    };

    // NOTE(review): the argument 5 is presumably the experiment time budget in seconds - confirm.
    CrossValidationExperimentResult<RankingMetrics> experimentResult = mlContext.Auto()
        .CreateRankingExperiment(5)
        .Execute(trainDataView, numFolds, columnInformation);

    Assert.NotEmpty(experimentResult.RunDetails);

    CrossValidationRunDetail<RankingMetrics> bestRun = experimentResult.BestRun;

    // foreach disposes the enumerator, which the hand-rolled MoveNext loop did not.
    foreach (var foldResult in bestRun.Results)
    {
        Assert.True(foldResult.ValidationMetrics.NormalizedDiscountedCumulativeGains.Max() > .4);
        Assert.True(foldResult.ValidationMetrics.DiscountedCumulativeGains.Max() > 19);
    }
}
218+
185219
[Fact]
186220
public void AutoFitRecommendationTest()
187221
{

0 commit comments

Comments
 (0)