Skip to content

Commit 9d3e545

Browse files
BigBigMiaoBigBigMiao
authored andcommitted
filter infinity value when calculate average score
1 parent 4bc2d78 commit 9d3e545

File tree

2 files changed

+30
-24
lines changed

2 files changed

+30
-24
lines changed

src/Microsoft.ML.AutoML/Experiment/Runners/CrossValRunner.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ public CrossValRunner(MLContext context,
6969

7070
private static double CalcAverageScore(IEnumerable<double> scores)
7171
{
72-
var newScores = scores.Where(r => !double.IsNaN(r));
73-
// Return NaN iff all scores are NaN
72+
var newScores = scores.Where(r => !double.IsNaN(r) && !double.IsInfinity(r));
73+
// Return NaN iff all scores are NaN or infinity.
7474
if (newScores.Count() == 0)
7575
return double.NaN;
7676
return newScores.Average();

src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public CrossValSummaryRunner(MLContext context,
8080
var bestModel = trainResults.ElementAt(bestFoldIndex).model;
8181

8282
// Get the average metrics across all folds
83-
var avgScore = GetAverageOfNonNaNScores(trainResults.Select(x => x.score));
83+
var avgScore = GetAverageOfNonNaNAndNonInfinityScores(trainResults.Select(x => x.score));
8484
var indexClosestToAvg = GetIndexClosestToAverage(trainResults.Select(r => r.score), avgScore);
8585
var metricsClosestToAvg = trainResults[indexClosestToAvg].metrics;
8686
var avgMetrics = GetAverageMetrics(trainResults.Select(x => x.metrics), metricsClosestToAvg);
@@ -99,14 +99,14 @@ private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetric
9999
Contracts.Assert(newMetrics != null);
100100

101101
var result = new BinaryClassificationMetrics(
102-
auc: GetAverageOfNonNaNScores(newMetrics.Select(x => x.AreaUnderRocCurve)),
103-
accuracy: GetAverageOfNonNaNScores(newMetrics.Select(x => x.Accuracy)),
104-
positivePrecision: GetAverageOfNonNaNScores(newMetrics.Select(x => x.PositivePrecision)),
105-
positiveRecall: GetAverageOfNonNaNScores(newMetrics.Select(x => x.PositiveRecall)),
106-
negativePrecision: GetAverageOfNonNaNScores(newMetrics.Select(x => x.NegativePrecision)),
107-
negativeRecall: GetAverageOfNonNaNScores(newMetrics.Select(x => x.NegativeRecall)),
108-
f1Score: GetAverageOfNonNaNScores(newMetrics.Select(x => x.F1Score)),
109-
auprc: GetAverageOfNonNaNScores(newMetrics.Select(x => x.AreaUnderPrecisionRecallCurve)),
102+
auc: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.AreaUnderRocCurve)),
103+
accuracy: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.Accuracy)),
104+
positivePrecision: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.PositivePrecision)),
105+
positiveRecall: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.PositiveRecall)),
106+
negativePrecision: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.NegativePrecision)),
107+
negativeRecall: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.NegativeRecall)),
108+
f1Score: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.F1Score)),
109+
auprc: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.AreaUnderPrecisionRecallCurve)),
110110
// Return ConfusionMatrix from the fold closest to average score
111111
confusionMatrix: (metricsClosestToAvg as BinaryClassificationMetrics).ConfusionMatrix);
112112
return result as TMetrics;
@@ -118,12 +118,12 @@ private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetric
118118
Contracts.Assert(newMetrics != null);
119119

120120
var result = new MulticlassClassificationMetrics(
121-
accuracyMicro: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MicroAccuracy)),
122-
accuracyMacro: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MacroAccuracy)),
123-
logLoss: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLoss)),
124-
logLossReduction: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLossReduction)),
121+
accuracyMicro: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.MicroAccuracy)),
122+
accuracyMacro: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.MacroAccuracy)),
123+
logLoss: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.LogLoss)),
124+
logLossReduction: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.LogLossReduction)),
125125
topKPredictionCount: newMetrics.ElementAt(0).TopKPredictionCount,
126-
topKAccuracy: GetAverageOfNonNaNScores(newMetrics.Select(x => x.TopKAccuracy)),
126+
topKAccuracy: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.TopKAccuracy)),
127127
// Return PerClassLogLoss and ConfusionMatrix from the fold closest to average score
128128
perClassLogLoss: (metricsClosestToAvg as MulticlassClassificationMetrics).PerClassLogLoss.ToArray(),
129129
confusionMatrix: (metricsClosestToAvg as MulticlassClassificationMetrics).ConfusionMatrix);
@@ -136,23 +136,29 @@ private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetric
136136
Contracts.Assert(newMetrics != null);
137137

138138
var result = new RegressionMetrics(
139-
l1: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MeanAbsoluteError)),
140-
l2: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MeanSquaredError)),
141-
rms: GetAverageOfNonNaNScores(newMetrics.Select(x => x.RootMeanSquaredError)),
142-
lossFunction: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LossFunction)),
143-
rSquared: GetAverageOfNonNaNScores(newMetrics.Select(x => x.RSquared)));
139+
l1: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.MeanAbsoluteError)),
140+
l2: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.MeanSquaredError)),
141+
rms: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.RootMeanSquaredError)),
142+
lossFunction: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.LossFunction)),
143+
rSquared: GetAverageOfNonNaNAndNonInfinityScores(newMetrics.Select(x => x.RSquared)));
144144
return result as TMetrics;
145145
}
146146

147147
throw new NotImplementedException($"Metric {typeof(TMetrics)} not implemented");
148148
}
149149

150-
private static double GetAverageOfNonNaNScores(IEnumerable<double> results)
150+
/// <summary>
151+
/// return the average of non-nan and non-infinity scores from <paramref name="results"/>, if all scores are nan/infinity, nan will be returned.
152+
/// </summary>
153+
/// <param name="results"></param>
154+
/// <returns></returns>
155+
private static double GetAverageOfNonNaNAndNonInfinityScores(IEnumerable<double> results)
151156
{
152-
var newResults = results.Where(r => !double.IsNaN(r));
153-
// Return NaN iff all scores are NaN
157+
var newResults = results.Where(r => !double.IsNaN(r) && !double.IsInfinity(r));
158+
// Return NaN iff all scores are NaN or Infinity
154159
if (newResults.Count() == 0)
155160
return double.NaN;
161+
156162
// Return average of non-NaN scores otherwise
157163
return newResults.Average(r => r);
158164
}

0 commit comments

Comments
 (0)