Skip to content

Commit

Permalink
Add non-calibrated evaluation to PFI
Browse files Browse the repository at this point in the history
  • Loading branch information
yaeldMS committed Dec 18, 2019
1 parent f1f8942 commit 60b652e
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,15 @@ internal CalibratedBinaryClassificationMetrics(IHost host, DataViewRow overallRe
LogLossReduction = Fetch(BinaryClassifierEvaluator.LogLossReduction);
Entropy = Fetch(BinaryClassifierEvaluator.Entropy);
}

/// <summary>
/// Builds a <see cref="CalibratedBinaryClassificationMetrics"/> directly from already-computed
/// metric values instead of fetching them from an evaluator result row.
/// </summary>
/// <remarks>
/// The first eight arguments (<paramref name="auc"/> through <paramref name="auprc"/>) are
/// forwarded unchanged to the base-class constructor; the remaining three populate the
/// calibrated-only properties of this class.
/// </remarks>
/// <param name="auc">Forwarded to the base constructor.</param>
/// <param name="accuracy">Forwarded to the base constructor.</param>
/// <param name="positivePrecision">Forwarded to the base constructor.</param>
/// <param name="positiveRecall">Forwarded to the base constructor.</param>
/// <param name="negativePrecision">Forwarded to the base constructor.</param>
/// <param name="negativeRecall">Forwarded to the base constructor.</param>
/// <param name="f1Score">Forwarded to the base constructor.</param>
/// <param name="auprc">Forwarded to the base constructor.</param>
/// <param name="logLoss">Value assigned to <c>LogLoss</c>.</param>
/// <param name="logLossReduction">Value assigned to <c>LogLossReduction</c>.</param>
/// <param name="entropy">Value assigned to <c>Entropy</c>.</param>
[BestFriend]
internal CalibratedBinaryClassificationMetrics(
    double auc,
    double accuracy,
    double positivePrecision,
    double positiveRecall,
    double negativePrecision,
    double negativeRecall,
    double f1Score,
    double auprc,
    double logLoss,
    double logLossReduction,
    double entropy)
    : base(
        auc,
        accuracy,
        positivePrecision,
        positiveRecall,
        negativePrecision,
        negativeRecall,
        f1Score,
        auprc)
{
    // Only the calibrated-specific metrics are stored here; everything else lives in the base class.
    LogLoss = logLoss;
    LogLossReduction = logLossReduction;
    Entropy = entropy;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ public static ImmutableArray<TResult>
int processedCnt = 0;
int nextFeatureIndex = 0;
var shuffleRand = RandomUtils.Create(host.Rand.Next());
using (var pch = host.StartProgressChannel("SDCA preprocessing with lookup"))
using (var pch = host.StartProgressChannel("Calculating Permutation Feature Importance"))
{
pch.SetHeader(new ProgressHeader("processed slots"), e => e.SetProgress(0, processedCnt));
foreach (var workingIndx in workingFeatureIndices)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

using System.Collections.Generic;
using System.Collections.Immutable;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;

Expand Down Expand Up @@ -144,17 +146,43 @@ public static ImmutableArray<BinaryClassificationMetricsStatistics>
int? numberOfExamplesToUse = null,
int permutationCount = 1) where TModel : class
{
bool isCalibratedModel = false;
var type = predictionTransformer.Model.GetType();
if (type.IsGenericType)
{
var genArgs = type.GetGenericArguments();
if (Utils.Size(genArgs) == 2)
{
var calibratedModelType = typeof(CalibratedModelParametersBase<,>).MakeGenericType(genArgs);
if (calibratedModelType.IsAssignableFrom(type))
isCalibratedModel = true;
}
}
if (isCalibratedModel)
{
return PermutationFeatureImportance<TModel, BinaryClassificationMetrics, BinaryClassificationMetricsStatistics>.GetImportanceMetricsMatrix(
catalog.GetEnvironment(),
predictionTransformer,
data,
() => new BinaryClassificationMetricsStatistics(),
idv => catalog.Evaluate(idv, labelColumnName),
BinaryClassifierDelta,
predictionTransformer.FeatureColumnName,
permutationCount,
useFeatureWeightFilter,
numberOfExamplesToUse);
}
return PermutationFeatureImportance<TModel, BinaryClassificationMetrics, BinaryClassificationMetricsStatistics>.GetImportanceMetricsMatrix(
catalog.GetEnvironment(),
predictionTransformer,
data,
() => new BinaryClassificationMetricsStatistics(),
idv => catalog.Evaluate(idv, labelColumnName),
BinaryClassifierDelta,
predictionTransformer.FeatureColumnName,
permutationCount,
useFeatureWeightFilter,
numberOfExamplesToUse);
catalog.GetEnvironment(),
predictionTransformer,
data,
() => new BinaryClassificationMetricsStatistics(),
idv => catalog.EvaluateNonCalibrated(idv, labelColumnName),
BinaryClassifierDelta,
predictionTransformer.FeatureColumnName,
permutationCount,
useFeatureWeightFilter,
numberOfExamplesToUse);
}

private static BinaryClassificationMetrics BinaryClassifierDelta(
Expand All @@ -171,6 +199,23 @@ private static BinaryClassificationMetrics BinaryClassifierDelta(
auprc: a.AreaUnderPrecisionRecallCurve - b.AreaUnderPrecisionRecallCurve);
}

/// <summary>
/// Computes the element-wise difference <paramref name="a"/> minus <paramref name="b"/> of two
/// calibrated binary classification metric sets.
/// </summary>
/// <param name="a">The metrics to subtract from (minuend).</param>
/// <param name="b">The metrics to subtract (subtrahend).</param>
/// <returns>
/// A new <see cref="CalibratedBinaryClassificationMetrics"/> whose every metric is the
/// corresponding value of <paramref name="a"/> minus that of <paramref name="b"/>.
/// </returns>
private static CalibratedBinaryClassificationMetrics CalibratedBinaryClassifierDelta(
    CalibratedBinaryClassificationMetrics a, CalibratedBinaryClassificationMetrics b)
{
    // Metrics shared with the non-calibrated evaluation.
    double aucDelta = a.AreaUnderRocCurve - b.AreaUnderRocCurve;
    double accuracyDelta = a.Accuracy - b.Accuracy;
    double positivePrecisionDelta = a.PositivePrecision - b.PositivePrecision;
    double positiveRecallDelta = a.PositiveRecall - b.PositiveRecall;
    double negativePrecisionDelta = a.NegativePrecision - b.NegativePrecision;
    double negativeRecallDelta = a.NegativeRecall - b.NegativeRecall;
    double f1ScoreDelta = a.F1Score - b.F1Score;
    double auprcDelta = a.AreaUnderPrecisionRecallCurve - b.AreaUnderPrecisionRecallCurve;

    // Metrics that only exist on the calibrated metrics type.
    double logLossDelta = a.LogLoss - b.LogLoss;
    double logLossReductionDelta = a.LogLossReduction - b.LogLossReduction;
    double entropyDelta = a.Entropy - b.Entropy;

    return new CalibratedBinaryClassificationMetrics(
        aucDelta,
        accuracyDelta,
        positivePrecisionDelta,
        positiveRecallDelta,
        negativePrecisionDelta,
        negativeRecallDelta,
        f1ScoreDelta,
        auprcDelta,
        logLossDelta,
        logLossReductionDelta,
        entropyDelta);
}

#endregion Binary Classification

#region Multiclass Classification
Expand Down
12 changes: 12 additions & 0 deletions test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,18 @@ public void TestPfiBinaryClassificationOnSparseFeatures(bool saveModel)

Done();
}

/// <summary>
/// Verifies that Permutation Feature Importance can be computed for a binary classification
/// model trained with FastForest, i.e. the scenario exercised by the non-calibrated
/// evaluation path.
/// </summary>
[Fact]
public void TestBinaryClassificationWithoutCalibrator()
{
    // Arrange: load the data set with a boolean label in column 0 and features in columns 1-9.
    var dataPath = GetDataPath("breast-cancer.txt");
    var ff = ML.BinaryClassification.Trainers.FastForest();
    var data = ML.Data.LoadFromTextFile(dataPath,
        new[] { new TextLoader.Column("Label", DataKind.Boolean, 0),
                new TextLoader.Column("Features", DataKind.Single, 1, 9) });
    var model = ff.Fit(data);

    // Act: this call is the behavior under test.
    var pfi = ML.BinaryClassification.PermutationFeatureImportance(model, data);

    // Assert: the call must actually yield per-feature statistics rather than an
    // uninitialized or empty result.
    Assert.False(pfi.IsDefaultOrEmpty);

    // Finish the test the same way the other tests in this file do.
    Done();
}
#endregion

#region Multiclass Classification Tests
Expand Down

0 comments on commit 60b652e

Please sign in to comment.