From 60b652ec2e2b69cf2cd1c85495026cda2d861d28 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Wed, 18 Dec 2019 17:07:28 +0200 Subject: [PATCH] Add non-calibrated evaluation to PFI --- .../CalibratedBinaryClassificationMetrics.cs | 10 +++ .../PermutationFeatureImportance.cs | 2 +- .../PermutationFeatureImportanceExtensions.cs | 65 ++++++++++++++++--- .../PermutationFeatureImportanceTests.cs | 12 ++++ 4 files changed, 78 insertions(+), 11 deletions(-) diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs index a2d193deed5..9e21fe1cf24 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs @@ -49,5 +49,15 @@ internal CalibratedBinaryClassificationMetrics(IHost host, DataViewRow overallRe LogLossReduction = Fetch(BinaryClassifierEvaluator.LogLossReduction); Entropy = Fetch(BinaryClassifierEvaluator.Entropy); } + + [BestFriend] + internal CalibratedBinaryClassificationMetrics(double auc, double accuracy, double positivePrecision, double positiveRecall, + double negativePrecision, double negativeRecall, double f1Score, double auprc, double logLoss, double logLossReduction, double entropy) + : base(auc, accuracy, positivePrecision, positiveRecall, negativePrecision, negativeRecall, f1Score, auprc) + { + LogLoss = logLoss; + LogLossReduction = logLossReduction; + Entropy = entropy; + } } } diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs index 1bd034ff742..648e49f565b 100644 --- a/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs @@ -171,7 +171,7 @@ public static ImmutableArray int processedCnt = 0; int nextFeatureIndex = 0; var shuffleRand = RandomUtils.Create(host.Rand.Next()); - using (var pch = host.StartProgressChannel("SDCA preprocessing with lookup")) + using (var pch = host.StartProgressChannel("Calculating Permutation Feature Importance")) { pch.SetHeader(new ProgressHeader("processed slots"), e => e.SetProgress(0, processedCnt)); foreach (var workingIndx in workingFeatureIndices) diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs index 07b9c8f4352..3ffc85cfac9 100644 --- a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs +++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs @@ -4,7 +4,9 @@ using System.Collections.Generic; using System.Collections.Immutable; +using Microsoft.ML.Calibrators; using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; @@ -144,17 +146,43 @@ public static ImmutableArray int? numberOfExamplesToUse = null, int permutationCount = 1) where TModel : class { + bool isCalibratedModel = false; + var type = predictionTransformer.Model.GetType(); + if (type.IsGenericType) + { + var genArgs = type.GetGenericArguments(); + if (Utils.Size(genArgs) == 2) + { + var calibratedModelType = typeof(CalibratedModelParametersBase<,>).MakeGenericType(genArgs); + if (calibratedModelType.IsAssignableFrom(type)) + isCalibratedModel = true; + } + } + if (isCalibratedModel) + { + return PermutationFeatureImportance.GetImportanceMetricsMatrix( + catalog.GetEnvironment(), + predictionTransformer, + data, + () => new BinaryClassificationMetricsStatistics(), + idv => catalog.Evaluate(idv, labelColumnName), + BinaryClassifierDelta, + predictionTransformer.FeatureColumnName, + permutationCount, + useFeatureWeightFilter, + numberOfExamplesToUse); + } return PermutationFeatureImportance.GetImportanceMetricsMatrix( - catalog.GetEnvironment(), - predictionTransformer, - data, - () => new BinaryClassificationMetricsStatistics(), - idv => catalog.Evaluate(idv, labelColumnName), - BinaryClassifierDelta, - predictionTransformer.FeatureColumnName, - permutationCount, - useFeatureWeightFilter, - numberOfExamplesToUse); + catalog.GetEnvironment(), + predictionTransformer, + data, + () => new BinaryClassificationMetricsStatistics(), + idv => catalog.EvaluateNonCalibrated(idv, labelColumnName), + BinaryClassifierDelta, + predictionTransformer.FeatureColumnName, + permutationCount, + useFeatureWeightFilter, + numberOfExamplesToUse); } private static BinaryClassificationMetrics BinaryClassifierDelta( @@ -171,6 +199,23 @@ private static BinaryClassificationMetrics BinaryClassifierDelta( auprc: a.AreaUnderPrecisionRecallCurve - b.AreaUnderPrecisionRecallCurve); } + private static CalibratedBinaryClassificationMetrics CalibratedBinaryClassifierDelta( + CalibratedBinaryClassificationMetrics a, CalibratedBinaryClassificationMetrics b) + { + return new CalibratedBinaryClassificationMetrics( + auc: a.AreaUnderRocCurve - b.AreaUnderRocCurve, + accuracy: a.Accuracy - b.Accuracy, + positivePrecision: a.PositivePrecision - b.PositivePrecision, + positiveRecall: a.PositiveRecall - b.PositiveRecall, + negativePrecision: a.NegativePrecision - b.NegativePrecision, + negativeRecall: a.NegativeRecall - b.NegativeRecall, + f1Score: a.F1Score - b.F1Score, + auprc: a.AreaUnderPrecisionRecallCurve - b.AreaUnderPrecisionRecallCurve, + logLoss: a.LogLoss - b.LogLoss, + logLossReduction: a.LogLossReduction - b.LogLossReduction, + entropy: a.Entropy - b.Entropy); + } + #endregion Binary Classification #region Multiclass Classification diff --git a/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs b/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs index 3feca9421b7..88e301be84d 100644 --- a/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs +++ b/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs @@ -305,6 +305,18 @@ public void TestPfiBinaryClassificationOnSparseFeatures(bool saveModel) Done(); } + + [Fact] + public void TestBinaryClassificationWithoutCalibrator() + { + var dataPath = GetDataPath("breast-cancer.txt"); + var ff = ML.BinaryClassification.Trainers.FastForest(); + var data = ML.Data.LoadFromTextFile(dataPath, + new[] { new TextLoader.Column("Label", DataKind.Boolean, 0), + new TextLoader.Column("Features", DataKind.Single, 1, 9) }); + var model = ff.Fit(data); + var pfi = ML.BinaryClassification.PermutationFeatureImportance(model, data); + } #endregion #region Multiclass Classification Tests