-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix PFI issue in binary classification #4587
Changes from 9 commits
4ed3ca8
fe48872
b6ed4b3
0559fb7
6558551
2d53161
bd2af8e
6e53e5b
8bde46b
cd7f27c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -145,16 +145,16 @@ public static ImmutableArray<BinaryClassificationMetricsStatistics> | |
int permutationCount = 1) where TModel : class | ||
{ | ||
return PermutationFeatureImportance<TModel, BinaryClassificationMetrics, BinaryClassificationMetricsStatistics>.GetImportanceMetricsMatrix( | ||
catalog.GetEnvironment(), | ||
predictionTransformer, | ||
data, | ||
() => new BinaryClassificationMetricsStatistics(), | ||
idv => catalog.Evaluate(idv, labelColumnName), | ||
BinaryClassifierDelta, | ||
predictionTransformer.FeatureColumnName, | ||
permutationCount, | ||
useFeatureWeightFilter, | ||
numberOfExamplesToUse); | ||
catalog.GetEnvironment(), | ||
predictionTransformer, | ||
data, | ||
() => new BinaryClassificationMetricsStatistics(), | ||
idv => catalog.EvaluateNonCalibrated(idv, labelColumnName), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Small Nit: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Resolving this comment, since In reply to: 362685943 [](ancestors = 362685943) |
||
BinaryClassifierDelta, | ||
predictionTransformer.FeatureColumnName, | ||
permutationCount, | ||
useFeatureWeightFilter, | ||
numberOfExamplesToUse); | ||
} | ||
|
||
private static BinaryClassificationMetrics BinaryClassifierDelta( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -305,6 +305,36 @@ public void TestPfiBinaryClassificationOnSparseFeatures(bool saveModel) | |
|
||
Done(); | ||
} | ||
|
||
[Fact] | ||
public void TestBinaryClassificationWithoutCalibrator() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test does not Assert anything. Can you please include Asserts for the relevant results that this test is supposed to verify? #Resolved |
||
{ | ||
var dataPath = GetDataPath("breast-cancer.txt"); | ||
var ff = ML.BinaryClassification.Trainers.FastForest(); | ||
var data = ML.Data.LoadFromTextFile(dataPath, | ||
new[] { new TextLoader.Column("Label", DataKind.Boolean, 0), | ||
new TextLoader.Column("Features", DataKind.Single, 1, 9) }); | ||
var model = ff.Fit(data); | ||
var pfi = ML.BinaryClassification.PermutationFeatureImportance(model, data); | ||
|
||
// For the following metrics higher is better, so minimum delta means more important feature, and vice versa | ||
Assert.Equal(7, MaxDeltaIndex(pfi, m => m.AreaUnderRocCurve.Mean)); | ||
Assert.Equal(1, MinDeltaIndex(pfi, m => m.AreaUnderRocCurve.Mean)); | ||
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.Accuracy.Mean)); | ||
Assert.Equal(1, MinDeltaIndex(pfi, m => m.Accuracy.Mean)); | ||
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.PositivePrecision.Mean)); | ||
Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositivePrecision.Mean)); | ||
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.PositiveRecall.Mean)); | ||
Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositiveRecall.Mean)); | ||
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.NegativePrecision.Mean)); | ||
Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativePrecision.Mean)); | ||
Assert.Equal(2, MaxDeltaIndex(pfi, m => m.NegativeRecall.Mean)); | ||
Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativeRecall.Mean)); | ||
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.F1Score.Mean)); | ||
Assert.Equal(1, MinDeltaIndex(pfi, m => m.F1Score.Mean)); | ||
Assert.Equal(7, MaxDeltaIndex(pfi, m => m.AreaUnderPrecisionRecallCurve.Mean)); | ||
Assert.Equal(1, MinDeltaIndex(pfi, m => m.AreaUnderPrecisionRecallCurve.Mean)); | ||
} | ||
#endregion | ||
|
||
#region Multiclass Classification Tests | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I believe this constructor is no longer used (after you removed the other code handling the calibrated case). Is there a reason the keep this constructor?