diff --git a/src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs b/src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs
index 04accc4754..0079be3ade 100644
--- a/src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs
+++ b/src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs
@@ -123,8 +123,7 @@ private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetric
                     logLoss: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLoss)),
                     logLossReduction: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLossReduction)),
                     topKPredictionCount: newMetrics.ElementAt(0).TopKPredictionCount,
-                    topKAccuracy: GetAverageOfNonNaNScores(newMetrics.Select(x => x.TopKAccuracy)),
-                    // Return PerClassLogLoss and ConfusionMatrix from the fold closest to average score
+                    topKAccuracies: GetAverageOfNonNaNScoresInNestedEnumerable(newMetrics.Select(x => x.TopKAccuracyForAllK)),
                     perClassLogLoss: (metricsClosestToAvg as MulticlassClassificationMetrics).PerClassLogLoss.ToArray(),
                     confusionMatrix: (metricsClosestToAvg as MulticlassClassificationMetrics).ConfusionMatrix);
                 return result as TMetrics;
@@ -163,7 +162,6 @@ private static double[] GetAverageOfNonNaNScoresInNestedEnumerable(IEnumerable<
                 arr[i] = GetAverageOfNonNaNScores(results.Select(x => x.ElementAt(i)));
             }
             return arr;
diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
index 07bec9516e..96ee121400 100644
--- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
+++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
@@ -1035,7 +1035,13 @@ private static List<string> GetMetricNames(IChannel ch, DataViewSchema schema, D
                     names = editor.Commit();
                 }
                 foreach (var name in names.Items(all: true))
-                    metricNames.Add(string.Format("{0}{1}", metricName, name.Value));
+                {
+                    var tryNaming = string.Format(metricName, name.Value);
+                    if (tryNaming == metricName) // metricName wasn't a format string, so just append the slot name
+                        tryNaming = string.Format("{0}{1}", metricName, name.Value);
+
+                    metricNames.Add(tryNaming);
+                }
             }
         }
         ch.Assert(metricNames.Count == metricCount);
diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs
index 05d8f050d0..c6d2495506 100644
--- a/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs
+++ b/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs
@@ -2,8 +2,10 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System;
 using System.Collections.Generic;
 using System.Collections.Immutable;
+using System.Linq;
 using Microsoft.ML.Runtime;
 
 namespace Microsoft.ML.Data
@@ -71,16 +73,22 @@ public sealed class MulticlassClassificationMetrics
         public double MicroAccuracy { get; }
 
         /// <summary>
-        /// If <see cref="TopKPredictionCount"/> is positive, this is the relative number of examples where
-        /// the true label is one of the top-k predicted labels by the predictor.
+        /// Convenience property for "TopKAccuracyForAllK[TopKPredictionCount - 1]". If <see cref="TopKPredictionCount"/> is positive,
+        /// this is the relative number of examples where
+        /// the true label is one of the top K predicted labels by the predictor.
         /// </summary>
-        public double TopKAccuracy { get; }
+        public double TopKAccuracy => TopKAccuracyForAllK?.LastOrDefault() ?? 0;
 
         /// <summary>
-        /// If positive, this indicates the K in <see cref="TopKAccuracy"/>.
+        /// If positive, this indicates the K in <see cref="TopKAccuracy"/> and <see cref="TopKAccuracyForAllK"/>.
         /// </summary>
         public int TopKPredictionCount { get; }
 
+        /// <summary>
+        /// Returns the top K accuracy for all K from 1 to the value of <see cref="TopKPredictionCount"/>.
+        /// </summary>
+        public IReadOnlyList<double> TopKAccuracyForAllK { get; }
+
         /// <summary>
         /// Gets the log-loss of the classifier for each class. Log-loss measures the performance of a classifier
         /// with respect to how much the predicted probabilities diverge from the true class label. Lower
@@ -115,8 +123,9 @@ internal MulticlassClassificationMetrics(IHost host, DataViewRow overallResult,
             LogLoss = FetchDouble(MulticlassClassificationEvaluator.LogLoss);
             LogLossReduction = FetchDouble(MulticlassClassificationEvaluator.LogLossReduction);
             TopKPredictionCount = topKPredictionCount;
+
             if (topKPredictionCount > 0)
-                TopKAccuracy = FetchDouble(MulticlassClassificationEvaluator.TopKAccuracy);
+                TopKAccuracyForAllK = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.AllTopKAccuracy).DenseValues().ToImmutableArray();
 
             var perClassLogLoss = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.PerClassLogLoss);
             PerClassLogLoss = perClassLogLoss.DenseValues().ToImmutableArray();
@@ -124,20 +133,20 @@ internal MulticlassClassificationMetrics(IHost host, DataViewRow overallResult,
         }
 
         internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction,
-            int topKPredictionCount, double topKAccuracy, double[] perClassLogLoss)
+            int topKPredictionCount, double[] topKAccuracies, double[] perClassLogLoss)
         {
             MicroAccuracy = accuracyMicro;
             MacroAccuracy = accuracyMacro;
             LogLoss = logLoss;
             LogLossReduction = logLossReduction;
             TopKPredictionCount = topKPredictionCount;
-            TopKAccuracy = topKAccuracy;
+            TopKAccuracyForAllK = topKAccuracies;
             PerClassLogLoss = perClassLogLoss.ToImmutableArray();
         }
 
         internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction,
-            int topKPredictionCount, double topKAccuracy, double[] perClassLogLoss, ConfusionMatrix confusionMatrix)
-            : this(accuracyMicro, accuracyMacro, logLoss, logLossReduction, topKPredictionCount, topKAccuracy, perClassLogLoss)
+            int topKPredictionCount, double[] topKAccuracies, double[] perClassLogLoss, ConfusionMatrix confusionMatrix)
+            : this(accuracyMicro, accuracyMacro, logLoss, logLossReduction, topKPredictionCount, topKAccuracies, perClassLogLoss)
         {
             ConfusionMatrix = confusionMatrix;
         }
diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassificationEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassificationEvaluator.cs
index e616d19a55..f084867b96 100644
--- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassificationEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassificationEvaluator.cs
@@ -41,6 +41,7 @@ public sealed class Arguments
         public const string AccuracyMicro = "Accuracy(micro-avg)";
         public const string AccuracyMacro = "Accuracy(macro-avg)";
         public const string TopKAccuracy = "Top K accuracy";
+        public const string AllTopKAccuracy = "Top K accuracies";
         public const string PerClassLogLoss = "Per class log-loss";
         public const string LogLoss = "Log-loss";
         public const string LogLossReduction = "Log-loss reduction";
@@ -147,6 +148,7 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
             var logLoss = new List<double>();
             var logLossRed = new List<double>();
             var topKAcc = new List<double>();
+            var allTopK = new List<double[]>();
             var perClassLogLoss = new List<double[]>();
             var counts = new List<double[]>();
             var weights = new List<double[]>();
@@ -171,7 +173,10 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
                     logLoss.Add(agg.UnweightedCounters.LogLoss);
                     logLossRed.Add(agg.UnweightedCounters.Reduction);
                     if (agg.UnweightedCounters.OutputTopKAcc > 0)
+                    {
                         topKAcc.Add(agg.UnweightedCounters.TopKAccuracy);
+                        allTopK.Add(agg.UnweightedCounters.AllTopKAccuracy);
+                    }
                     perClassLogLoss.Add(agg.UnweightedCounters.PerClassLogLoss);
 
                     confStratCol.AddRange(agg.UnweightedCounters.ConfusionTable.Select(x => stratColKey));
@@ -188,7 +193,10 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
                         logLoss.Add(agg.WeightedCounters.LogLoss);
                         logLossRed.Add(agg.WeightedCounters.Reduction);
                         if (agg.WeightedCounters.OutputTopKAcc > 0)
+                        {
                             topKAcc.Add(agg.WeightedCounters.TopKAccuracy);
+                            allTopK.Add(agg.WeightedCounters.AllTopKAccuracy);
+                        }
                         perClassLogLoss.Add(agg.WeightedCounters.PerClassLogLoss);
                         weights.AddRange(agg.WeightedCounters.ConfusionTable);
                     }
@@ -210,7 +218,15 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
                 overallDvBldr.AddColumn(LogLoss, NumberDataViewType.Double, logLoss.ToArray());
                 overallDvBldr.AddColumn(LogLossReduction, NumberDataViewType.Double, logLossRed.ToArray());
                 if (aggregator.UnweightedCounters.OutputTopKAcc > 0)
+                {
                     overallDvBldr.AddColumn(TopKAccuracy, NumberDataViewType.Double, topKAcc.ToArray());
+
+                    ValueGetter<VBuffer<ReadOnlyMemory<char>>> getKSlotNames =
+                        (ref VBuffer<ReadOnlyMemory<char>> dst) =>
+                            dst = new VBuffer<ReadOnlyMemory<char>>(allTopK.First().Length, Enumerable.Range(1, allTopK.First().Length).Select(i => new ReadOnlyMemory<char>(i.ToString().ToCharArray())).ToArray());
+                    overallDvBldr.AddColumn(AllTopKAccuracy, getKSlotNames, NumberDataViewType.Double, allTopK.ToArray());
+                }
+
                 overallDvBldr.AddColumn(PerClassLogLoss, aggregator.GetSlotNames, NumberDataViewType.Double, perClassLogLoss.ToArray());
 
                 var confDvBldr = new ArrayDataViewBuilder(Host);
@@ -246,9 +262,10 @@ public sealed class Counters
             private double _totalLogLoss;
             private double _numInstances;
             private double _numCorrect;
-            private double _numCorrectTopK;
             private readonly double[] _sumWeightsOfClass;
             private readonly double[] _totalPerClassLogLoss;
+            private readonly double[] _seenRanks;
+
             public readonly double[][] ConfusionTable;
 
             public double MicroAvgAccuracy { get { return _numInstances > 0 ? _numCorrect / _numInstances : 0; } }
@@ -291,7 +308,8 @@ public double Reduction
                 }
             }
 
-            public double TopKAccuracy { get { return _numInstances > 0 ? _numCorrectTopK / _numInstances : 0; } }
+            public double TopKAccuracy => !(OutputTopKAcc is null) ? AllTopKAccuracy[OutputTopKAcc.Value - 1] : 0d;
+            public double[] AllTopKAccuracy => CumulativeSum(_seenRanks.Take(OutputTopKAcc ?? 0).Select(l => l / _numInstances)).ToArray();
 
             // The per class average log loss is calculated by dividing the weighted sum of the log loss of examples
             // in each class by the total weight of examples in that class.
@@ -316,14 +334,12 @@ public Counters(int numClasses, int? outputTopKAcc)
                 ConfusionTable = new double[numClasses][];
                 for (int i = 0; i < ConfusionTable.Length; i++)
                     ConfusionTable[i] = new double[numClasses];
+
+                _seenRanks = new double[numClasses + 1];
             }
 
-            public void Update(int[] indices, double loglossCurr, int label, float weight)
+            public void Update(int seenRank, int assigned, double loglossCurr, int label, float weight)
             {
-                Contracts.Assert(Utils.Size(indices) == _numClasses);
-
-                int assigned = indices[0];
-
                 _numInstances += weight;
 
                 if (label < _numClasses)
@@ -334,23 +350,30 @@ public void Update(int[] indices, double loglossCurr, int label, float weight)
                 if (label < _numClasses)
                     _totalPerClassLogLoss[label] += loglossCurr * weight;
 
-                if (assigned == label)
+                _seenRanks[seenRank] += weight;
+
+                if (seenRank == 0) // Prediction matched the label
                 {
                     _numCorrect += weight;
                     ConfusionTable[label][label] += weight;
-                    _numCorrectTopK += weight;
                 }
                 else if (label < _numClasses)
                 {
-                    if (OutputTopKAcc > 0)
-                    {
-                        int idx = Array.IndexOf(indices, label);
-                        if (0 <= idx && idx < OutputTopKAcc)
-                            _numCorrectTopK += weight;
-                    }
                     ConfusionTable[label][assigned] += weight;
                 }
             }
+
+            private static IEnumerable<double> CumulativeSum(IEnumerable<double> s)
+            {
+                double sum = 0;
+
+                foreach (var x in s)
+                {
+                    sum += x;
+                    yield return sum;
+                }
+            }
         }
 
         private ValueGetter<float> _labelGetter;
@@ -359,7 +382,6 @@ public void Update(int[] indices, double loglossCurr, int label, float weight)
 
         private VBuffer<float> _scores;
         private readonly float[] _scoresArr;
-        private int[] _indicesArr;
 
         private const float Epsilon = (float)1e-15;
@@ -380,6 +402,7 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory<char>[] classNames, int s
             Host.Assert(Utils.Size(classNames) == scoreVectorSize);
 
             _scoresArr = new float[scoreVectorSize];
+
             UnweightedCounters = new Counters(scoreVectorSize, outputTopKAcc);
             Weighted = weighted;
             WeightedCounters = Weighted ? new Counters(scoreVectorSize, outputTopKAcc) : null;
@@ -400,6 +423,7 @@ internal override void InitializeNextPass(DataViewRow row, RoleMappedSchema sche
 
             if (schema.Weight.HasValue)
                 _weightGetter = row.GetGetter<float>(schema.Weight.Value);
+
         }
 
         public override void ProcessRow()
@@ -437,16 +461,10 @@ public override void ProcessRow()
                 }
             }
 
-            // Sort classes by prediction strength.
-            // Use stable OrderBy instead of Sort(), which may give different results on different machines.
-            if (Utils.Size(_indicesArr) < _scoresArr.Length)
-                _indicesArr = new int[_scoresArr.Length];
-            int j = 0;
-            foreach (var index in Enumerable.Range(0, _scoresArr.Length).OrderByDescending(i => _scoresArr[i]))
-                _indicesArr[j++] = index;
-
             var intLabel = (int)label;
 
+            var wasKnownLabel = true;
+
             // log-loss
             double logloss;
             if (intLabel < _scoresArr.Length)
@@ -461,11 +479,37 @@ public override void ProcessRow()
                 // Penalize logloss if the label was not seen during training
                 logloss = -Math.Log(Epsilon);
                 _numUnknownClassInstances++;
+                wasKnownLabel = false;
             }
 
-            UnweightedCounters.Update(_indicesArr, logloss, intLabel, 1);
+            // Get the probability that the CORRECT label has (best case: it's the highest probability):
+            var correctProba = !wasKnownLabel ? 0 : _scoresArr[intLabel];
+
+            // Find the rank of the *correct* label (in _scoresArr[]). If the correct (ground truth) label gets rank 0,
+            // it means the model assigned it the highest probability (that's ideal). Rank 1 would mean our model
+            // gives the real label the 2nd highest probability, etc.
+            // The rank will be from 0 to N (not N-1); rank N is used for unrecognized values.
+            //
+            // Tie breaking: What if we have probabilities that are equal to the correct prediction (e.g., a:0.1, b:0.1,
+            // c:0.1, d:0.6, e:0.1, where c is the correct label)?
+            // This actually happens a lot with some models. We handle ties by assigning rank in order of first
+            // appearance. In this example, we assign c the rank of 3, because d has a higher probability and a and b
+            // appear first sequentially.
+            int rankOfCorrectLabel = 0;
+            int assigned = 0;
+            for (int i = 0; i < _scoresArr.Length; i++)
+            {
+                if (_scoresArr[i] > correctProba || (_scoresArr[i] == correctProba && i < intLabel))
+                    rankOfCorrectLabel++;
+
+                // This is the assigned "prediction" of the model if it has the highest probability.
+                if (_scoresArr[assigned] < _scoresArr[i])
+                    assigned = i;
+            }
+
+            UnweightedCounters.Update(rankOfCorrectLabel, assigned, logloss, intLabel, 1);
             if (WeightedCounters != null)
-                WeightedCounters.Update(_indicesArr, logloss, intLabel, weight);
+                WeightedCounters.Update(rankOfCorrectLabel, assigned, logloss, intLabel, weight);
         }
 
         protected override List<string> GetWarningsCore()
@@ -882,13 +926,16 @@ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
@@ ... @@ public override IEnumerable<MetricColumn> GetOverallMetricColumns()
             }
             yield return new MetricColumn("LogLoss", MulticlassClassificationEvaluator.LogLoss, MetricColumn.Objective.Minimize);
             yield return new MetricColumn("LogLossReduction", MulticlassClassificationEvaluator.LogLossReduction);
+            yield return new MetricColumn("TopKAccuracyForAllK", MulticlassClassificationEvaluator.AllTopKAccuracy, isVector: true);
         }
 
         private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs
index 02e8832f0e..8be82d7fde 100644
--- a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs
+++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs
@@ -4,6 +4,7 @@
 
 using System.Collections.Generic;
 using System.Collections.Immutable;
+using System.Linq;
 using Microsoft.ML.Data;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Transforms;
@@ -251,7 +252,7 @@ private static MulticlassClassificationMetrics MulticlassClassificationDelta(
                 logLoss: a.LogLoss - b.LogLoss,
                 logLossReduction: a.LogLossReduction - b.LogLossReduction,
                 topKPredictionCount: a.TopKPredictionCount,
-                topKAccuracy: a.TopKAccuracy - b.TopKAccuracy,
+                topKAccuracies: a?.TopKAccuracyForAllK?.Zip(b.TopKAccuracyForAllK, (ai, bi) => ai - bi)?.ToArray(),
                 perClassLogLoss: perClassLogLoss
                 );
         }
diff --git a/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs b/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs
index 0f9b336d84..72b8d33bee 100644
--- a/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs
+++ b/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs
@@ -61,7 +61,7 @@ public void BinaryMetricsPerfectTest()
         [Fact]
         public void MulticlassMetricsGetScoreTest()
         {
-            var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0, 0.5, new double[] {});
+            var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0, new double[] {0.5}, new double[] {});
             Assert.Equal(0.1, GetScore(metrics, MulticlassClassificationMetric.MicroAccuracy));
             Assert.Equal(0.2, GetScore(metrics, MulticlassClassificationMetric.MacroAccuracy));
             Assert.Equal(0.3, GetScore(metrics, MulticlassClassificationMetric.LogLoss));
@@ -72,7 +72,7 @@ public void MulticlassMetricsNonPerfectTest()
         {
-            var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0, 0.5, new double[] { });
+            var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0, new double[] { 0.5 }, new double[] { });
             Assert.False(IsPerfectModel(metrics, MulticlassClassificationMetric.MacroAccuracy));
             Assert.False(IsPerfectModel(metrics, MulticlassClassificationMetric.MicroAccuracy));
             Assert.False(IsPerfectModel(metrics, MulticlassClassificationMetric.LogLoss));
@@ -83,7 +83,7 @@ public void MulticlassMetricsPerfectTest()
         {
-            var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(1, 1, 0, 1, 0, 1, new double[] { });
+            var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(1, 1, 0, 1, 0, new double[] { 1 }, new double[] { });
             Assert.True(IsPerfectModel(metrics, MulticlassClassificationMetric.MicroAccuracy));
             Assert.True(IsPerfectModel(metrics, MulticlassClassificationMetric.MacroAccuracy));
             Assert.True(IsPerfectModel(metrics, MulticlassClassificationMetric.LogLoss));
diff --git a/test/Microsoft.ML.AutoML.Tests/MetricsUtil.cs b/test/Microsoft.ML.AutoML.Tests/MetricsUtil.cs
index 828eccf9d2..3f306fdcbe 100644
--- a/test/Microsoft.ML.AutoML.Tests/MetricsUtil.cs
+++ b/test/Microsoft.ML.AutoML.Tests/MetricsUtil.cs
@@ -21,7 +21,7 @@ public static BinaryClassificationMetrics CreateBinaryClassificationMetrics(
 
         public static MulticlassClassificationMetrics CreateMulticlassClassificationMetrics(
             double accuracyMicro, double accuracyMacro, double logLoss,
-            double logLossReduction, int topK, double topKAccuracy,
+            double logLossReduction, int topK, double[] topKAccuracy,
             double[] perClassLogLoss)
         {
             return CreateInstance<MulticlassClassificationMetrics>(accuracyMicro,
diff --git a/test/Microsoft.ML.PerformanceTests/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.PerformanceTests/StochasticDualCoordinateAscentClassifierBench.cs
index 6908af3c01..e54df47295 100644
--- a/test/Microsoft.ML.PerformanceTests/StochasticDualCoordinateAscentClassifierBench.cs
+++ b/test/Microsoft.ML.PerformanceTests/StochasticDualCoordinateAscentClassifierBench.cs
@@ -37,6 +37,8 @@ public class StochasticDualCoordinateAscentClassifierBench : WithExtraMetrics
         private PredictionEngine<IrisData, IrisPrediction> _predictionEngine;
         private IrisData[][] _batches;
         private MulticlassClassificationMetrics _metrics;
+        private MulticlassClassificationEvaluator _evaluator;
+        private IDataView _scoredIrisTestData;
 
         protected override IEnumerable<Metric> GetMetrics()
         {
@@ -118,7 +120,7 @@ public void TrainSentiment()
             _consumer.Consume(predicted);
         }
 
-        [GlobalSetup(Targets = new string[] { nameof(PredictIris), nameof(PredictIrisBatchOf1), nameof(PredictIrisBatchOf2), nameof(PredictIrisBatchOf5) })]
+        [GlobalSetup(Targets = new string[] { nameof(PredictIris), nameof(PredictIrisBatchOf1), nameof(PredictIrisBatchOf2), nameof(PredictIrisBatchOf5), nameof(EvaluateMetrics) })]
         public void SetupPredictBenchmarks()
         {
             _trainedModel = Train(_dataPath);
@@ -141,9 +143,9 @@ public void SetupPredictBenchmarks()
             var loader = new TextLoader(_mlContext, options: options);
             IDataView testData = loader.Load(_dataPath);
-            IDataView scoredTestData = _trainedModel.Transform(testData);
-            var evaluator = new MulticlassClassificationEvaluator(_mlContext, new MulticlassClassificationEvaluator.Arguments());
-            _metrics = evaluator.Evaluate(scoredTestData, DefaultColumnNames.Label, DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel);
+            _scoredIrisTestData = _trainedModel.Transform(testData);
+            _evaluator = new MulticlassClassificationEvaluator(_mlContext, new MulticlassClassificationEvaluator.Arguments());
+            _metrics = _evaluator.Evaluate(_scoredIrisTestData, DefaultColumnNames.Label, DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel);
 
             _batches = new IrisData[_batchSizes.Length][];
             for (int i = 0; i < _batches.Length; i++)
@@ -168,6 +170,9 @@ public void SetupPredictBenchmarks()
 
         [Benchmark]
         public void PredictIrisBatchOf5() => _trainedModel.Transform(_mlContext.Data.LoadFromEnumerable(_batches[2]));
+
+        [Benchmark]
+        public void EvaluateMetrics() => _evaluator.Evaluate(_scoredIrisTestData, DefaultColumnNames.Label, DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel);
     }
 
     public class IrisData
diff --git a/test/Microsoft.ML.Tests/EvaluateTests.cs b/test/Microsoft.ML.Tests/EvaluateTests.cs
new file mode 100644
index 0000000000..ee5c58016f
--- /dev/null
+++ b/test/Microsoft.ML.Tests/EvaluateTests.cs
@@ -0,0 +1,67 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Linq;
+using Microsoft.ML.Data;
+using Microsoft.ML.TestFramework;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.Tests
+{
+    public class EvaluateTests : BaseTestClass
+    {
+        public EvaluateTests(ITestOutputHelper output)
+            : base(output)
+        {
+        }
+
+        public class MulticlassEvaluatorInput
+        {
+            public float Label { get; set; }
+
+            [VectorType(4)]
+            public float[] Score { get; set; }
+
+            public float PredictedLabel { get; set; }
+        }
+
+        [Fact]
+        public void MulticlassEvaluatorTopKArray()
+        {
+            var mlContext = new MLContext(seed: 1);
+
+            // Notice that the probability assigned to the correct label (i.e. Score[0])
+            // decreases on each row so as to get the expected TopK accuracy array hardcoded below.
+            var inputArray = new[]
+            {
+                new MulticlassEvaluatorInput { Label = 0, Score = new[] { 0.4f, 0.3f, 0.2f, 0.1f }, PredictedLabel = 0 },
+                new MulticlassEvaluatorInput { Label = 0, Score = new[] { 0.3f, 0.4f, 0.2f, 0.1f }, PredictedLabel = 1 },
+                new MulticlassEvaluatorInput { Label = 0, Score = new[] { 0.2f, 0.3f, 0.4f, 0.1f }, PredictedLabel = 2 },
+                new MulticlassEvaluatorInput { Label = 0, Score = new[] { 0.1f, 0.3f, 0.2f, 0.4f }, PredictedLabel = 3 },
+            };
+
+            var expectedTopKArray = new[] { 0.25d, 0.5d, 0.75d, 1.0d };
+
+            var inputDV = mlContext.Data.LoadFromEnumerable(inputArray);
+            var metrics = mlContext.MulticlassClassification.Evaluate(inputDV, topKPredictionCount: 4);
+            Assert.Equal(expectedTopKArray, metrics.TopKAccuracyForAllK);
+
+            // After introducing a sample whose label was unseen (i.e. the Score array doesn't assign it a probability),
+            // the Top K array changes, as its values are divided by the total number of instances
+            // that were evaluated.
+            var inputArray2 = inputArray.AppendElement(new MulticlassEvaluatorInput {
+                Label = 5, Score = new[] { 0.1f, 0.3f, 0.2f, 0.4f }, PredictedLabel = 3 });
+
+            var expectedTopKArray2 = new[] { 0.2d, 0.4d, 0.6d, 0.8d };
+
+            var inputDV2 = mlContext.Data.LoadFromEnumerable(inputArray2);
+            var metrics2 = mlContext.MulticlassClassification.Evaluate(inputDV2, topKPredictionCount: 4);
+            var output2 = metrics2.TopKAccuracyForAllK.ToArray();
+            for (int i = 0; i < expectedTopKArray2.Length; i++)
+                Assert.Equal(expectedTopKArray2[i], output2[i], precision: 7);
+        }
+    }
+}
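Usage note (not part of the diff): a minimal sketch of how the new TopKAccuracyForAllK surface behaves, reusing the pre-scored-input approach from EvaluateTests above. The demo type, values, and entry point are illustrative; the only APIs assumed are the ones exercised by this change (LoadFromEnumerable, Evaluate(..., topKPredictionCount: ...), TopKAccuracyForAllK, TopKAccuracy).

    using System;
    using System.Linq;
    using Microsoft.ML;
    using Microsoft.ML.Data;

    public static class TopKAccuracyForAllKDemo
    {
        // Same shape as the test fixture above: label, 4-slot score vector, predicted label.
        public class ScoredRow
        {
            public float Label { get; set; }

            [VectorType(4)]
            public float[] Score { get; set; }

            public float PredictedLabel { get; set; }
        }

        public static void Main()
        {
            var mlContext = new MLContext(seed: 1);

            // Two pre-scored rows: the true label is ranked 0 (best) in the first row
            // and ranked 3 (last) in the second.
            var rows = new[]
            {
                new ScoredRow { Label = 0, Score = new[] { 0.4f, 0.3f, 0.2f, 0.1f }, PredictedLabel = 0 },
                new ScoredRow { Label = 0, Score = new[] { 0.1f, 0.3f, 0.2f, 0.4f }, PredictedLabel = 3 },
            };

            var data = mlContext.Data.LoadFromEnumerable(rows);
            var metrics = mlContext.MulticlassClassification.Evaluate(data, topKPredictionCount: 4);

            // One accuracy per K, for K = 1..TopKPredictionCount. It is a cumulative sum over
            // the rank histogram (_seenRanks), so values are non-decreasing: 0.50, 0.50, 0.50, 1.00.
            for (int k = 1; k <= metrics.TopKPredictionCount; k++)
                Console.WriteLine($"Top-{k} accuracy: {metrics.TopKAccuracyForAllK[k - 1]:F2}");

            // The existing scalar property now reads the last entry of the array.
            Console.WriteLine(metrics.TopKAccuracy == metrics.TopKAccuracyForAllK.Last()); // True
        }
    }

Design note: ranking the correct label with a single linear pass replaces the per-row OrderByDescending sort removed from ProcessRow, so the top-K bookkeeping drops from O(N log N) to O(N) per row, which is what the new EvaluateMetrics benchmark is there to watch.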