-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Enabling Ranking Cross Validation #5263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
39f1e6a
6a8a242
827c1c6
1621e99
b47d353
e4a714f
e2c9a93
6e2792a
567865e
41905b8
25afec7
fc883a0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,11 +67,23 @@ internal ExperimentBase(MLContext context, | |
public ExperimentResult<TMetrics> Execute(IDataView trainData, string labelColumnName = DefaultColumnNames.Label, | ||
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<TMetrics>> progressHandler = null) | ||
{ | ||
var columnInformation = new ColumnInformation() | ||
ColumnInformation columnInformation; | ||
if (_task == TaskKind.Ranking) | ||
{ | ||
LabelColumnName = labelColumnName, | ||
SamplingKeyColumnName = samplingKeyColumn | ||
}; | ||
columnInformation = new ColumnInformation() | ||
{ | ||
LabelColumnName = labelColumnName, | ||
GroupIdColumnName = samplingKeyColumn ?? DefaultColumnNames.GroupId | ||
}; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Suggestion: As per the feedback we got from @justinormont today, I think it would be better to set both SamplingKeyColumnName and GroupIdColumnName: columnInformation = new ColumnInformation()
{
LabelColumnName = labelColumnName,
SamplingKeyColumnName = samplingKeyColumn ?? DefaultColumnNames.GroupId,
GroupIdColumnName = samplingKeyColumn ?? DefaultColumnNames.GroupId // For ranking, we want to enforce having the same column as samplingKeyColum and GroupIdColumn
} With your current implementation it won't make any difference to do this, but I do think this might be clearer for future AutoML.NET developers. A similar change would need to take place in the other overload that receives a ColumnInformation. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would lean towards deferring this to the next update. I will take a quick look, but having a column with two column infos seems to be causing issues. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's just a two-line change (adding the line in here, and in the other overload), and it's just to have it clear in the columnInformation object that we'll be using the […] In reply to: 452987600 [](ancestors = 452987600) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just to be clear, mapping a groupId column to both […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I know the current implementation throws if they're not the same. That's why I suggested using the […] In general, if the user provides a […] In reply to: 453055824 [](ancestors = 453055824) |
||
} | ||
else | ||
{ | ||
columnInformation = new ColumnInformation() | ||
{ | ||
LabelColumnName = labelColumnName, | ||
SamplingKeyColumnName = samplingKeyColumn | ||
}; | ||
} | ||
return Execute(trainData, columnInformation, preFeaturizer, progressHandler); | ||
} | ||
|
||
|
@@ -102,19 +114,28 @@ public ExperimentResult<TMetrics> Execute(IDataView trainData, ColumnInformation | |
const int crossValRowCountThreshold = 15000; | ||
|
||
var rowCount = DatasetDimensionsUtil.CountRows(trainData, crossValRowCountThreshold); | ||
var samplingKeyColumnName = GetSamplingKey(columnInformation?.GroupIdColumnName, columnInformation?.SamplingKeyColumnName); | ||
if (rowCount < crossValRowCountThreshold) | ||
{ | ||
const int numCrossValFolds = 10; | ||
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numCrossValFolds, columnInformation?.SamplingKeyColumnName); | ||
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numCrossValFolds, samplingKeyColumnName); | ||
return ExecuteCrossValSummary(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler); | ||
} | ||
else | ||
{ | ||
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName); | ||
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, samplingKeyColumnName); | ||
return ExecuteTrainValidate(splitResult.trainData, columnInformation, splitResult.validationData, preFeaturizer, progressHandler); | ||
} | ||
} | ||
|
||
/// <summary>
/// Validates the sampling-key / group-id column pair for the current task and returns
/// the column name to use as the sampling key.
/// </summary>
/// <param name="groupIdColumnName">Group id column name; may be null.</param>
/// <param name="samplingKeyColumnName">Sampling key column name; may be null.</param>
/// <returns>
/// For ranking tasks, <paramref name="groupIdColumnName"/> or, when that is null,
/// <c>DefaultColumnNames.GroupId</c>. For every other task, <paramref name="samplingKeyColumnName"/> unchanged.
/// </returns>
private string GetSamplingKey(string groupIdColumnName, string samplingKeyColumnName) | ||
{ | ||
// Validation runs before any column is chosen.
// NOTE(review): presumably throws on an inconsistent pair -- confirm in UserInputValidationUtil.
UserInputValidationUtil.ValidateSamplingKey(samplingKeyColumnName, groupIdColumnName, _task); | ||
// Ranking uses the group id column as the sampling key, falling back to the default group id name.
if ( _task == TaskKind.Ranking) | ||
return groupIdColumnName ?? DefaultColumnNames.GroupId; | ||
return samplingKeyColumnName; | ||
} | ||
|
||
Lynx1820 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// <summary> | ||
/// Executes an AutoML experiment. | ||
/// </summary> | ||
|
@@ -194,7 +215,8 @@ public CrossValidationExperimentResult<TMetrics> Execute(IDataView trainData, ui | |
IProgress<CrossValidationRunDetail<TMetrics>> progressHandler = null) | ||
{ | ||
UserInputValidationUtil.ValidateNumberOfCVFoldsArg(numberOfCVFolds); | ||
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, columnInformation?.SamplingKeyColumnName); | ||
var samplingKeyColumnName = GetSamplingKey(columnInformation?.GroupIdColumnName, columnInformation?.SamplingKeyColumnName); | ||
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, samplingKeyColumnName); | ||
return ExecuteCrossVal(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler); | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
using Microsoft.ML.TestFramework; | ||
using Microsoft.ML.TestFramework.Attributes; | ||
using Microsoft.ML.TestFrameworkCommon; | ||
using Microsoft.ML.Trainers.LightGbm; | ||
using Xunit; | ||
using Xunit.Abstractions; | ||
using static Microsoft.ML.DataOperationsCatalog; | ||
|
@@ -156,6 +157,40 @@ public void AutoFitRankingTest() | |
Assert.True(col.Name == expectedOutputNames[col.Index]); | ||
} | ||
|
||
[LightGBMFact] | ||
public void AutoFitRankingCVTest() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the way experiments are used within codegen. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If I recall correctly, if your dataset has less than 15000 lines of data, AutoML will run CrossValidation automatically; if you have more than 15000 pieces of data, it will use a train-test split instead. So the rest of the tests in AutoFitTests should all be CV runs, considering that the dataset they use is really small. (@justinormont correct me if I'm wrong.) Tests starting with AutoFit should test the AutoML ranking experiment API, so you shouldn't have to create your pipeline from scratch in this test. If you just want to test the Ranking.CrossValidation command, consider renaming it more specifically. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's the other way around. If it has less than 15000 it runs train test split automatically on one of the Execute overloads, if it has more it runs CV. This only happens on 1 overload, but I believe Keren isn't using that overload in her tests. In reply to: 447156630 [](ancestors = 447156630) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have added CV testing for ranking only. I think it would be good to add testing for other tasks as well in the future. |
||
{ | ||
string labelColumnName = "Label"; | ||
string groupIdColumnName = "GroupIdCustom"; | ||
string featuresColumnVectorNameA = "FeatureVectorA"; | ||
Lynx1820 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
string featuresColumnVectorNameB = "FeatureVectorB"; | ||
uint numFolds = 3; | ||
|
||
var mlContext = new MLContext(1); | ||
var reader = new TextLoader(mlContext, GetLoaderArgsRank(labelColumnName, groupIdColumnName, | ||
featuresColumnVectorNameA, featuresColumnVectorNameB)); | ||
var trainDataView = reader.Load(new MultiFileSource(DatasetUtil.GetMLSRDataset())); | ||
|
||
CrossValidationExperimentResult<RankingMetrics> experimentResult = mlContext.Auto() | ||
.CreateRankingExperiment(new RankingExperimentSettings() { GroupIdColumnName = groupIdColumnName, MaxExperimentTimeInSeconds = 5 }) | ||
Lynx1820 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
.Execute(trainDataView, numFolds, | ||
new ColumnInformation() | ||
{ | ||
LabelColumnName = labelColumnName, | ||
GroupIdColumnName = groupIdColumnName | ||
}); | ||
|
||
CrossValidationRunDetail<RankingMetrics> bestRun = experimentResult.BestRun; | ||
Assert.True(experimentResult.RunDetails.Count() > 0); | ||
var enumerator = bestRun.Results.GetEnumerator(); | ||
while (enumerator.MoveNext()) | ||
{ | ||
var model = enumerator.Current; | ||
Assert.True(model.ValidationMetrics.NormalizedDiscountedCumulativeGains.Max() > .4); | ||
Assert.True(model.ValidationMetrics.DiscountedCumulativeGains.Max() > 19); | ||
} | ||
} | ||
|
||
[Fact] | ||
public void AutoFitRecommendationTest() | ||
{ | ||
|
Uh oh!
There was an error while loading. Please reload this page.