Skip to content

Commit

Permalink
Exposing the confusion matrix (#3250)
Browse files Browse the repository at this point in the history
* Exposing the confusion matrix
  • Loading branch information
sfilipi authored Apr 20, 2019
1 parent d987294 commit 610ffcb
Show file tree
Hide file tree
Showing 11 changed files with 323 additions and 59 deletions.
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/AnnotationUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
{
var cols = new List<SchemaShape.Column>();
if (labelColumn.HasValue && labelColumn.Value.IsKey)
if (labelColumn != null && labelColumn.Value.IsKey)
{
if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) &&
metaCol.Kind == SchemaShape.Column.VectorKind.Vector)
Expand Down
16 changes: 11 additions & 5 deletions src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -815,16 +815,18 @@ public CalibratedBinaryClassificationMetrics Evaluate(IDataView data, string lab
var resultDict = ((IEvaluator)this).Evaluate(roles);
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
var overall = resultDict[MetricKinds.OverallMetrics];
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];

CalibratedBinaryClassificationMetrics result;
using (var cursor = overall.GetRowCursorForAllColumns())
{
var moved = cursor.MoveNext();
Host.Assert(moved);
result = new CalibratedBinaryClassificationMetrics(Host, cursor);
result = new CalibratedBinaryClassificationMetrics(Host, cursor, confusionMatrix);
moved = cursor.MoveNext();
Host.Assert(!moved);
}

return result;
}

Expand Down Expand Up @@ -879,13 +881,14 @@ public CalibratedBinaryClassificationMetrics EvaluateWithPRCurve(
}
}
prCurve = prCurveResult;
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];

CalibratedBinaryClassificationMetrics result;
using (var cursor = overall.GetRowCursorForAllColumns())
{
var moved = cursor.MoveNext();
Host.Assert(moved);
result = new CalibratedBinaryClassificationMetrics(Host, cursor);
result = new CalibratedBinaryClassificationMetrics(Host, cursor, confusionMatrix);
moved = cursor.MoveNext();
Host.Assert(!moved);
}
Expand Down Expand Up @@ -939,16 +942,18 @@ public BinaryClassificationMetrics Evaluate(IDataView data, string label, string
var resultDict = ((IEvaluator)this).Evaluate(roles);
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
var overall = resultDict[MetricKinds.OverallMetrics];
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];

BinaryClassificationMetrics result;
using (var cursor = overall.GetRowCursorForAllColumns())
{
var moved = cursor.MoveNext();
Host.Assert(moved);
result = new BinaryClassificationMetrics(Host, cursor);
result = new BinaryClassificationMetrics(Host, cursor, confusionMatrix);
moved = cursor.MoveNext();
Host.Assert(!moved);
}

return result;
}

Expand Down Expand Up @@ -985,6 +990,7 @@ public BinaryClassificationMetrics EvaluateWithPRCurve(
var prCurveView = resultDict[MetricKinds.PrCurve];
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
var overall = resultDict[MetricKinds.OverallMetrics];
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];

var prCurveResult = new List<BinaryPrecisionRecallDataPoint>();
using (var cursor = prCurveView.GetRowCursorForAllColumns())
Expand All @@ -1007,7 +1013,7 @@ public BinaryClassificationMetrics EvaluateWithPRCurve(
{
var moved = cursor.MoveNext();
Host.Assert(moved);
result = new BinaryClassificationMetrics(Host, cursor);
result = new BinaryClassificationMetrics(Host, cursor, confusionMatrix);
moved = cursor.MoveNext();
Host.Assert(!moved);
}
Expand Down Expand Up @@ -1377,7 +1383,7 @@ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary<str
fold = ColumnSelectingTransformer.CreateKeep(Host, fold, colsToKeep.ToArray());

string weightedConf;
var unweightedConf = MetricWriter.GetConfusionTable(Host, conf, out weightedConf);
var unweightedConf = MetricWriter.GetConfusionTableAsFormattedString(Host, conf, out weightedConf);
string weightedFold;
var unweightedFold = MetricWriter.GetPerFoldResults(Host, fold, out weightedFold);
ch.Assert(string.IsNullOrEmpty(weightedConf) == string.IsNullOrEmpty(weightedFold));
Expand Down
111 changes: 69 additions & 42 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1348,20 +1348,49 @@ internal static class MetricWriter
/// is assigned the string representation of the weighted confusion table. Otherwise it is assigned null.</param>
/// <param name="binary">Indicates whether the confusion table is for binary classification.</param>
/// <param name="sample">Indicates how many classes to sample from the confusion table (-1 indicates no sampling)</param>
public static string GetConfusionTable(IHost host, IDataView confusionDataView, out string weightedConfusionTable, bool binary = true, int sample = -1)
public static string GetConfusionTableAsFormattedString(IHost host, IDataView confusionDataView, out string weightedConfusionTable, bool binary = true, int sample = -1)
{
host.CheckValue(confusionDataView, nameof(confusionDataView));
host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2");

// Get the class names.
int countCol;
host.Check(confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.Count, out countCol), "Did not find the count column");
var type = confusionDataView.Schema[countCol].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType;
host.Check(type != null && type.IsKnownSize && type.ItemType is TextDataViewType, "The Count column does not have a text vector metadata of kind SlotNames.");
var weightColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight);
bool isWeighted = weightColumn.HasValue;

var confusionMatrix = GetConfusionMatrix(host, confusionDataView, binary, sample, false);
var confusionTableString = GetConfusionTableAsString(confusionMatrix, false);

// If there is a Weight column, return the weighted confusionMatrix as well, from this function.
if (isWeighted)
{
confusionMatrix = GetConfusionMatrix(host, confusionDataView, binary, sample, true);
weightedConfusionTable = GetConfusionTableAsString(confusionMatrix, true);
}
else
weightedConfusionTable = null;

return confusionTableString;
}

public static ConfusionMatrix GetConfusionMatrix(IHost host, IDataView confusionDataView, bool binary = true, int sample = -1, bool getWeighted = false)
{
host.CheckValue(confusionDataView, nameof(confusionDataView));
host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2");

// check that there is a Weight column, if isWeighted parameter is set to true.
var weightColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight);
if (getWeighted)
host.CheckParam(weightColumn.HasValue, nameof(getWeighted), "There is no Weight column in the confusionMatrix data view.");

// Get the counts names.
var countColumn = confusionDataView.Schema[MetricKinds.ColumnNames.Count];
var type = countColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType;
//"The Count column does not have a text vector metadata of kind SlotNames."
host.Assert(type != null && type.IsKnownSize && type.ItemType is TextDataViewType);

// Get the class names
var labelNames = default(VBuffer<ReadOnlyMemory<char>>);
confusionDataView.Schema[countCol].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames);
host.Check(labelNames.IsDense, "Slot names vector must be dense");
countColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames);
host.Assert(labelNames.IsDense, "Slot names vector must be dense");

int numConfusionTableLabels = sample < 0 ? labelNames.Length : Math.Min(labelNames.Length, sample);

Expand All @@ -1387,32 +1416,32 @@ public static string GetConfusionTable(IHost host, IDataView confusionDataView,

double[] precisionSums;
double[] recallSums;
var confusionTable = GetConfusionTableAsArray(confusionDataView, countCol, labelNames.Length,
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums);
double[][] confusionTable;

var predictedLabelNames = GetPredictedLabelNames(in labelNames, labelIndexToConfIndexMap);
var confusionTableString = GetConfusionTableAsString(confusionTable, recallSums, precisionSums,
predictedLabelNames,
sampled: numConfusionTableLabels < labelNames.Length, binary: binary);
if (getWeighted)
confusionTable = GetConfusionTableAsArray(confusionDataView, weightColumn.Value.Index, labelNames.Length,
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums);
else
confusionTable = GetConfusionTableAsArray(confusionDataView, countColumn.Index, labelNames.Length,
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums);

int weightIndex;
if (confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.Weight, out weightIndex))
double[] precision = new double[numConfusionTableLabels];
double[] recall = new double[numConfusionTableLabels];
for (int i = 0; i < numConfusionTableLabels; i++)
{
confusionTable = GetConfusionTableAsArray(confusionDataView, weightIndex, labelNames.Length,
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums);
weightedConfusionTable = GetConfusionTableAsString(confusionTable, recallSums, precisionSums,
predictedLabelNames,
sampled: numConfusionTableLabels < labelNames.Length, prefix: "Weighted ", binary: binary);
recall[i] = recallSums[i] > 0 ? confusionTable[i][i] / recallSums[i] : 0;
precision[i] = precisionSums[i] > 0 ? confusionTable[i][i] / precisionSums[i] : 0;
}
else
weightedConfusionTable = null;

return confusionTableString;
var predictedLabelNames = GetPredictedLabelNames(in labelNames, labelIndexToConfIndexMap);
bool sampled = numConfusionTableLabels < labelNames.Length;

return new ConfusionMatrix(host, precision, recall, confusionTable, predictedLabelNames, sampled, binary);
}

private static List<ReadOnlyMemory<char>> GetPredictedLabelNames(in VBuffer<ReadOnlyMemory<char>> labelNames, int[] labelIndexToConfIndexMap)
{
List<ReadOnlyMemory<char>> result = new List<ReadOnlyMemory<char>>();
List <ReadOnlyMemory<char>> result = new List<ReadOnlyMemory<char>>();
var values = labelNames.GetValues();
for (int i = 0; i < values.Length; i++)
{
Expand Down Expand Up @@ -1553,13 +1582,13 @@ private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView dat
}

// Get a string representation of a confusion table.
private static string GetConfusionTableAsString(double[][] confusionTable, double[] rowSums, double[] columnSums,
List<ReadOnlyMemory<char>> predictedLabelNames, string prefix = "", bool sampled = false, bool binary = true)
internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix, bool isWeighted)
{
int numLabels = Utils.Size(confusionTable);
string prefix = isWeighted ? "Weighted " : "";
int numLabels = confusionMatrix?.Counts == null? 0: confusionMatrix.Counts.Count;

int colWidth = numLabels == 2 ? 8 : 5;
int maxNameLen = predictedLabelNames.Max(name => name.Length);
int maxNameLen = confusionMatrix.PredictedClassesIndicators.Max(name => name.Length);
// If the names are too long to fit in the column header, we back off to using class indices
// in the header. This will also require putting the indices in the row, but it's better than
// the alternative of having ambiguous abbreviated column headers, or having a table potentially
Expand All @@ -1572,7 +1601,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
{
// The row label will also include the index, so a user can easily match against the header.
// In such a case, a label like "Foo" would be presented as something like "5. Foo".
rowDigitLen = Math.Max(predictedLabelNames.Count - 1, 0).ToString().Length;
rowDigitLen = Math.Max(confusionMatrix.PredictedClassesIndicators.Count - 1, 0).ToString().Length;
Contracts.Assert(rowDigitLen >= 1);
rowLabelLen += rowDigitLen + 2;
}
Expand All @@ -1591,10 +1620,11 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
else
rowLabelFormat = string.Format("{{1,{0}}} ||", paddingLen);

var confusionTable = confusionMatrix.Counts;
var sb = new StringBuilder();
if (numLabels == 2 && binary)
if (numLabels == 2 && confusionMatrix.IsBinary)
{
var positiveCaps = predictedLabelNames[0].ToString().ToUpper();
var positiveCaps = confusionMatrix.PredictedClassesIndicators[0].ToString().ToUpper();

var numTruePos = confusionTable[0][0];
var numFalseNeg = confusionTable[0][1];
Expand All @@ -1607,7 +1637,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl

sb.AppendLine();
sb.AppendFormat("{0}Confusion table", prefix);
if (sampled)
if (confusionMatrix.IsSampled)
sb.AppendLine(" (sampled)");
else
sb.AppendLine();
Expand All @@ -1619,7 +1649,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
sb.AppendFormat("PREDICTED {0}||", pad);
string format = string.Format(" {{{0},{1}}} |", useNumbersInHeader ? 0 : 1, colWidth);
for (int i = 0; i < numLabels; i++)
sb.AppendFormat(format, i, predictedLabelNames[i]);
sb.AppendFormat(format, i, confusionMatrix.PredictedClassesIndicators[i]);
sb.AppendLine(" Recall");
sb.AppendFormat("TRUTH {0}||", pad);
for (int i = 0; i < numLabels; i++)
Expand All @@ -1631,11 +1661,10 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
string.IsNullOrWhiteSpace(prefix) ? "N0" : "F1");
for (int i = 0; i < numLabels; i++)
{
sb.AppendFormat(rowLabelFormat, i, predictedLabelNames[i]);
sb.AppendFormat(rowLabelFormat, i, confusionMatrix.PredictedClassesIndicators[i]);
for (int j = 0; j < numLabels; j++)
sb.AppendFormat(format2, confusionTable[i][j]);
Double recall = rowSums[i] > 0 ? confusionTable[i][i] / rowSums[i] : 0;
sb.AppendFormat(" {0,5:F4}", recall);
sb.AppendFormat(" {0,5:F4}", confusionMatrix.PerClassRecall[i]);
sb.AppendLine();
}
sb.AppendFormat(" {0}||", pad);
Expand All @@ -1645,10 +1674,8 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
sb.AppendFormat("Precision {0}||", pad);
format = string.Format("{{0,{0}:N4}} |", colWidth + 1);
for (int i = 0; i < numLabels; i++)
{
Double precision = columnSums[i] > 0 ? confusionTable[i][i] / columnSums[i] : 0;
sb.AppendFormat(format, precision);
}
sb.AppendFormat(format, confusionMatrix.PerClassPrecision[i]);

sb.AppendLine();
return sb.ToString();
}
Expand Down Expand Up @@ -1701,7 +1728,7 @@ public static void PrintWarnings(IChannel ch, Dictionary<string, IDataView> metr
if (metrics.TryGetValue(MetricKinds.Warnings, out warnings))
{
var warningTextColumn = warnings.Schema.GetColumnOrNull(MetricKinds.ColumnNames.WarningText);
if (warningTextColumn !=null && warningTextColumn.HasValue && warningTextColumn.Value.Type is TextDataViewType)
if (warningTextColumn != null && warningTextColumn.HasValue && warningTextColumn.Value.Type is TextDataViewType)
{
using (var cursor = warnings.GetRowCursor(warnings.Schema[MetricKinds.ColumnNames.WarningText]))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ public class BinaryClassificationMetrics
/// </remarks>
public double AreaUnderPrecisionRecallCurve { get; }

/// <summary>
/// The <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a> giving the counts of the
/// true positives, true negatives, false positives and false negatives for the two classes of data.
/// </summary>
public ConfusionMatrix ConfusionMatrix { get; }

private protected static T Fetch<T>(IExceptionContext ectx, DataViewRow row, string name)
{
var column = row.Schema.GetColumnOrNull(name);
Expand All @@ -84,9 +90,9 @@ private protected static T Fetch<T>(IExceptionContext ectx, DataViewRow row, str
return val;
}

internal BinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overallResult)
internal BinaryClassificationMetrics(IHost host, DataViewRow overallResult, IDataView confusionMatrix)
{
double Fetch(string name) => Fetch<double>(ectx, overallResult, name);
double Fetch(string name) => Fetch<double>(host, overallResult, name);
AreaUnderRocCurve = Fetch(BinaryClassifierEvaluator.Auc);
Accuracy = Fetch(BinaryClassifierEvaluator.Accuracy);
PositivePrecision = Fetch(BinaryClassifierEvaluator.PosPrecName);
Expand All @@ -95,6 +101,7 @@ internal BinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overall
NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName);
F1Score = Fetch(BinaryClassifierEvaluator.F1);
AreaUnderPrecisionRecallCurve = Fetch(BinaryClassifierEvaluator.AuPrc);
ConfusionMatrix = MetricWriter.GetConfusionMatrix(host, confusionMatrix);
}

[BestFriend]
Expand Down
Loading

0 comments on commit 610ffcb

Please sign in to comment.