Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PFI issue in binary classification #4587

Merged
merged 10 commits into from
Jan 8, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,15 @@ internal CalibratedBinaryClassificationMetrics(IHost host, DataViewRow overallRe
LogLossReduction = Fetch(BinaryClassifierEvaluator.LogLossReduction);
Entropy = Fetch(BinaryClassifierEvaluator.Entropy);
}

[BestFriend]
internal CalibratedBinaryClassificationMetrics(double auc, double accuracy, double positivePrecision, double positiveRecall,
double negativePrecision, double negativeRecall, double f1Score, double auprc, double logLoss, double logLossReduction, double entropy)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I believe this constructor is no longer used (after you removed the other code handling the calibrated case). Is there a reason the keep this constructor?

: base(auc, accuracy, positivePrecision, positiveRecall, negativePrecision, negativeRecall, f1Score, auprc)
{
LogLoss = logLoss;
LogLossReduction = logLossReduction;
Entropy = entropy;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ public static ImmutableArray<TResult>
int processedCnt = 0;
int nextFeatureIndex = 0;
var shuffleRand = RandomUtils.Create(host.Rand.Next());
using (var pch = host.StartProgressChannel("SDCA preprocessing with lookup"))
using (var pch = host.StartProgressChannel("Calculating Permutation Feature Importance"))
{
pch.SetHeader(new ProgressHeader("processed slots"), e => e.SetProgress(0, processedCnt));
foreach (var workingIndx in workingFeatureIndices)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,16 @@ public static ImmutableArray<BinaryClassificationMetricsStatistics>
int permutationCount = 1) where TModel : class
{
return PermutationFeatureImportance<TModel, BinaryClassificationMetrics, BinaryClassificationMetricsStatistics>.GetImportanceMetricsMatrix(
catalog.GetEnvironment(),
predictionTransformer,
data,
() => new BinaryClassificationMetricsStatistics(),
idv => catalog.Evaluate(idv, labelColumnName),
BinaryClassifierDelta,
predictionTransformer.FeatureColumnName,
permutationCount,
useFeatureWeightFilter,
numberOfExamplesToUse);
catalog.GetEnvironment(),
predictionTransformer,
data,
() => new BinaryClassificationMetricsStatistics(),
idv => catalog.EvaluateNonCalibrated(idv, labelColumnName),
Copy link
Contributor

@harishsk harishsk Jan 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small Nit:
The only difference between the if (isCalibratedModel) and the else case is the idv parameter. Is it possible to make this a bit more readable by factoring out just that line and using a single call to the PermutationFeatureImportance constructor? #Resolved

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolving this comment, since isCalibratedModel has been removed.


In reply to: 362685943 [](ancestors = 362685943)

BinaryClassifierDelta,
predictionTransformer.FeatureColumnName,
permutationCount,
useFeatureWeightFilter,
numberOfExamplesToUse);
}

private static BinaryClassificationMetrics BinaryClassifierDelta(
Expand Down
30 changes: 30 additions & 0 deletions test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,36 @@ public void TestPfiBinaryClassificationOnSparseFeatures(bool saveModel)

Done();
}

[Fact]
public void TestBinaryClassificationWithoutCalibrator()
Copy link
Contributor

@harishsk harishsk Jan 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test does not Assert anything. Can you please include Asserts for the relevant results that this test is supposed to verify? #Resolved

{
var dataPath = GetDataPath("breast-cancer.txt");
var ff = ML.BinaryClassification.Trainers.FastForest();
var data = ML.Data.LoadFromTextFile(dataPath,
new[] { new TextLoader.Column("Label", DataKind.Boolean, 0),
new TextLoader.Column("Features", DataKind.Single, 1, 9) });
var model = ff.Fit(data);
var pfi = ML.BinaryClassification.PermutationFeatureImportance(model, data);

// For the following metrics higher is better, so minimum delta means more important feature, and vice versa
Assert.Equal(7, MaxDeltaIndex(pfi, m => m.AreaUnderRocCurve.Mean));
Assert.Equal(1, MinDeltaIndex(pfi, m => m.AreaUnderRocCurve.Mean));
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.Accuracy.Mean));
Assert.Equal(1, MinDeltaIndex(pfi, m => m.Accuracy.Mean));
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.PositivePrecision.Mean));
Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositivePrecision.Mean));
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.PositiveRecall.Mean));
Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositiveRecall.Mean));
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.NegativePrecision.Mean));
Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativePrecision.Mean));
Assert.Equal(2, MaxDeltaIndex(pfi, m => m.NegativeRecall.Mean));
Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativeRecall.Mean));
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.F1Score.Mean));
Assert.Equal(1, MinDeltaIndex(pfi, m => m.F1Score.Mean));
Assert.Equal(7, MaxDeltaIndex(pfi, m => m.AreaUnderPrecisionRecallCurve.Mean));
Assert.Equal(1, MinDeltaIndex(pfi, m => m.AreaUnderPrecisionRecallCurve.Mean));
}
#endregion

#region Multiclass Classification Tests
Expand Down