@@ -80,7 +80,7 @@ public CrossValSummaryRunner(MLContext context,
8080 var bestModel = trainResults . ElementAt ( bestFoldIndex ) . model ;
8181
8282 // Get the average metrics across all folds
83- var avgScore = GetAverageOfNonNaNScores ( trainResults . Select ( x => x . score ) ) ;
83+ var avgScore = GetAverageOfNonNaNAndNonInfinityScores ( trainResults . Select ( x => x . score ) ) ;
8484 var indexClosestToAvg = GetIndexClosestToAverage ( trainResults . Select ( r => r . score ) , avgScore ) ;
8585 var metricsClosestToAvg = trainResults [ indexClosestToAvg ] . metrics ;
8686 var avgMetrics = GetAverageMetrics ( trainResults . Select ( x => x . metrics ) , metricsClosestToAvg ) ;
@@ -99,14 +99,14 @@ private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetric
9999 Contracts . Assert ( newMetrics != null ) ;
100100
101101 var result = new BinaryClassificationMetrics (
102- auc : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . AreaUnderRocCurve ) ) ,
103- accuracy : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . Accuracy ) ) ,
104- positivePrecision : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . PositivePrecision ) ) ,
105- positiveRecall : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . PositiveRecall ) ) ,
106- negativePrecision : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . NegativePrecision ) ) ,
107- negativeRecall : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . NegativeRecall ) ) ,
108- f1Score : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . F1Score ) ) ,
109- auprc : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . AreaUnderPrecisionRecallCurve ) ) ,
102+ auc : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . AreaUnderRocCurve ) ) ,
103+ accuracy : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . Accuracy ) ) ,
104+ positivePrecision : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . PositivePrecision ) ) ,
105+ positiveRecall : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . PositiveRecall ) ) ,
106+ negativePrecision : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . NegativePrecision ) ) ,
107+ negativeRecall : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . NegativeRecall ) ) ,
108+ f1Score : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . F1Score ) ) ,
109+ auprc : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . AreaUnderPrecisionRecallCurve ) ) ,
110110 // Return ConfusionMatrix from the fold closest to average score
111111 confusionMatrix : ( metricsClosestToAvg as BinaryClassificationMetrics ) . ConfusionMatrix ) ;
112112 return result as TMetrics ;
@@ -118,12 +118,12 @@ private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetric
118118 Contracts . Assert ( newMetrics != null ) ;
119119
120120 var result = new MulticlassClassificationMetrics (
121- accuracyMicro : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . MicroAccuracy ) ) ,
122- accuracyMacro : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . MacroAccuracy ) ) ,
123- logLoss : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . LogLoss ) ) ,
124- logLossReduction : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . LogLossReduction ) ) ,
121+ accuracyMicro : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . MicroAccuracy ) ) ,
122+ accuracyMacro : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . MacroAccuracy ) ) ,
123+ logLoss : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . LogLoss ) ) ,
124+ logLossReduction : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . LogLossReduction ) ) ,
125125 topKPredictionCount : newMetrics . ElementAt ( 0 ) . TopKPredictionCount ,
126- topKAccuracy : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . TopKAccuracy ) ) ,
126+ topKAccuracy : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . TopKAccuracy ) ) ,
127127 // Return PerClassLogLoss and ConfusionMatrix from the fold closest to average score
128128 perClassLogLoss : ( metricsClosestToAvg as MulticlassClassificationMetrics ) . PerClassLogLoss . ToArray ( ) ,
129129 confusionMatrix : ( metricsClosestToAvg as MulticlassClassificationMetrics ) . ConfusionMatrix ) ;
@@ -136,23 +136,29 @@ private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetric
136136 Contracts . Assert ( newMetrics != null ) ;
137137
138138 var result = new RegressionMetrics (
139- l1 : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . MeanAbsoluteError ) ) ,
140- l2 : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . MeanSquaredError ) ) ,
141- rms : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . RootMeanSquaredError ) ) ,
142- lossFunction : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . LossFunction ) ) ,
143- rSquared : GetAverageOfNonNaNScores ( newMetrics . Select ( x => x . RSquared ) ) ) ;
139+ l1 : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . MeanAbsoluteError ) ) ,
140+ l2 : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . MeanSquaredError ) ) ,
141+ rms : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . RootMeanSquaredError ) ) ,
142+ lossFunction : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . LossFunction ) ) ,
143+ rSquared : GetAverageOfNonNaNAndNonInfinityScores ( newMetrics . Select ( x => x . RSquared ) ) ) ;
144144 return result as TMetrics ;
145145 }
146146
147147 throw new NotImplementedException ( $ "Metric { typeof ( TMetrics ) } not implemented") ;
148148 }
149149
150- private static double GetAverageOfNonNaNScores ( IEnumerable < double > results )
150+ /// <summary>
151+ /// return the average of non-nan and non-infinity scores from <paramref name="results"/>, if all scores are nan/infinity, nan will be returned.
152+ /// </summary>
153+ /// <param name="results"></param>
154+ /// <returns></returns>
155+ private static double GetAverageOfNonNaNAndNonInfinityScores ( IEnumerable < double > results )
151156 {
152- var newResults = results . Where ( r => ! double . IsNaN ( r ) ) ;
153- // Return NaN iff all scores are NaN
157+ var newResults = results . Where ( r => ! double . IsNaN ( r ) && ! double . IsInfinity ( r ) ) ;
158+ // Return NaN iff all scores are NaN or Infinity
154159 if ( newResults . Count ( ) == 0 )
155160 return double . NaN ;
161+
156162 // Return average of non-NaN scores otherwise
157163 return newResults . Average ( r => r ) ;
158164 }
0 commit comments