8
8
9
9
#include < maths/CBasicStatistics.h>
10
10
#include < maths/CDataFramePredictiveModel.h>
11
+ #include < maths/CSampling.h>
11
12
#include < maths/CTools.h>
13
+ #include < maths/CToolsDetail.h>
12
14
#include < maths/CTreeShapFeatureImportance.h>
13
15
14
16
#include < api/CDataFrameAnalyzer.h>
17
+ #include < api/CDataFrameTrainBoostedTreeRunner.h>
15
18
16
19
#include < test/CDataFrameAnalysisSpecificationFactory.h>
17
20
#include < test/CRandomNumbers.h>
@@ -27,12 +30,14 @@ using namespace ml;
27
30
28
31
namespace {
29
32
using TDoubleVec = std::vector<double >;
33
+ using TVector = maths::CDenseVector<double >;
30
34
using TStrVec = std::vector<std::string>;
31
35
using TRowItr = core::CDataFrame::TRowItr;
32
36
using TRowRef = core::CDataFrame::TRowRef;
33
37
using TMeanAccumulator = maths::CBasicStatistics::SSampleMean<double >::TAccumulator;
34
38
using TMeanAccumulatorVec = std::vector<TMeanAccumulator>;
35
39
using TMeanVarAccumulator = maths::CBasicStatistics::SSampleMeanVar<double >::TAccumulator;
40
+ using TMemoryMappedMatrix = maths::CMemoryMappedDenseMatrix<double >;
36
41
37
42
void setupLinearRegressionData (const TStrVec& fieldNames,
38
43
TStrVec& fieldValues,
@@ -128,6 +133,47 @@ void setupBinaryClassificationData(const TStrVec& fieldNames,
128
133
}
129
134
}
130
135
136
+ void setupMultiClassClassificationData (const TStrVec& fieldNames,
137
+ TStrVec& fieldValues,
138
+ api::CDataFrameAnalyzer& analyzer,
139
+ const TDoubleVec& weights,
140
+ const TDoubleVec& values) {
141
+ TStrVec classes{" foo" , " bar" , " baz" };
142
+ maths::CPRNG::CXorOShiro128Plus rng;
143
+ std::uniform_real_distribution<double > u01;
144
+ int numberFeatures{static_cast <int >(weights.size ())};
145
+ TDoubleVec w{weights};
146
+ int numberClasses{static_cast <int >(classes.size ())};
147
+ auto probability = [&](const TDoubleVec& row) {
148
+ TMemoryMappedMatrix W (&w[0 ], numberClasses, numberFeatures);
149
+ TVector x (numberFeatures);
150
+ for (int i = 0 ; i < numberFeatures; ++i) {
151
+ x (i) = row[i];
152
+ }
153
+ TVector logit{W * x};
154
+ return maths::CTools::softmax (std::move (logit));
155
+ };
156
+ auto target = [&](const TDoubleVec& row) {
157
+ TDoubleVec probabilities{probability (row).to <TDoubleVec>()};
158
+ return classes[maths::CSampling::categoricalSample (rng, probabilities)];
159
+ };
160
+
161
+ for (std::size_t i = 0 ; i < values.size (); i += weights.size ()) {
162
+ TDoubleVec row (weights.size ());
163
+ for (std::size_t j = 0 ; j < weights.size (); ++j) {
164
+ row[j] = values[i + j];
165
+ }
166
+
167
+ fieldValues[0 ] = target (row);
168
+ for (std::size_t j = 0 ; j < row.size (); ++j) {
169
+ fieldValues[j + 1 ] = core::CStringUtils::typeToStringPrecise (
170
+ row[j], core::CIEEE754::E_DoublePrecision);
171
+ }
172
+
173
+ analyzer.handleRecord (fieldNames, fieldValues);
174
+ }
175
+ }
176
+
131
177
struct SFixture {
132
178
rapidjson::Document
133
179
runRegression (std::size_t shapValues, TDoubleVec weights, double noiseVar = 0.0 ) {
@@ -231,6 +277,57 @@ struct SFixture {
231
277
return results;
232
278
}
233
279
280
+ rapidjson::Document runMultiClassClassification (std::size_t shapValues,
281
+ TDoubleVec&& weights) {
282
+ auto outputWriterFactory = [&]() {
283
+ return std::make_unique<core::CJsonOutputStreamWrapper>(s_Output);
284
+ };
285
+ test::CDataFrameAnalysisSpecificationFactory specFactory;
286
+ api::CDataFrameAnalyzer analyzer{
287
+ specFactory.rows (s_Rows)
288
+ .memoryLimit (26000000 )
289
+ .predictionCategoricalFieldNames ({" target" })
290
+ .predictionAlpha (s_Alpha)
291
+ .predictionLambda (s_Lambda)
292
+ .predictionGamma (s_Gamma)
293
+ .predictionSoftTreeDepthLimit (s_SoftTreeDepthLimit)
294
+ .predictionSoftTreeDepthTolerance (s_SoftTreeDepthTolerance)
295
+ .predictionEta (s_Eta)
296
+ .predictionMaximumNumberTrees (s_MaximumNumberTrees)
297
+ .predictionFeatureBagFraction (s_FeatureBagFraction)
298
+ .predictionNumberTopShapValues (shapValues)
299
+ .numberClasses (3 )
300
+ .numberTopClasses (3 )
301
+ .predictionSpec (test::CDataFrameAnalysisSpecificationFactory::classification (), " target" ),
302
+ outputWriterFactory};
303
+ TStrVec fieldNames{" target" , " c1" , " c2" , " c3" , " c4" , " ." , " ." };
304
+ TStrVec fieldValues{" " , " " , " " , " " , " " , " 0" , " " };
305
+ test::CRandomNumbers rng;
306
+
307
+ TDoubleVec values;
308
+ rng.generateUniformSamples (-10.0 , 10.0 , weights.size () * s_Rows, values);
309
+
310
+ setupMultiClassClassificationData (fieldNames, fieldValues, analyzer, weights, values);
311
+
312
+ analyzer.handleRecord (fieldNames, {" " , " " , " " , " " , " " , " " , " $" });
313
+
314
+ LOG_DEBUG (<< " estimated memory usage = "
315
+ << core::CProgramCounters::counter (counter_t ::E_DFTPMEstimatedPeakMemoryUsage));
316
+ LOG_DEBUG (<< " peak memory = "
317
+ << core::CProgramCounters::counter (counter_t ::E_DFTPMPeakMemoryUsage));
318
+ LOG_DEBUG (<< " time to train = " << core::CProgramCounters::counter (counter_t ::E_DFTPMTimeToTrain)
319
+ << " ms" );
320
+
321
+ BOOST_TEST_REQUIRE (
322
+ core::CProgramCounters::counter (counter_t ::E_DFTPMPeakMemoryUsage) <
323
+ core::CProgramCounters::counter (counter_t ::E_DFTPMEstimatedPeakMemoryUsage));
324
+
325
+ rapidjson::Document results;
326
+ rapidjson::ParseResult ok (results.Parse (s_Output.str ()));
327
+ BOOST_TEST_REQUIRE (static_cast <bool >(ok) == true );
328
+ return results;
329
+ }
330
+
234
331
rapidjson::Document runRegressionWithMissingFeatures (std::size_t shapValues) {
235
332
auto outputWriterFactory = [&]() {
236
333
return std::make_unique<core::CJsonOutputStreamWrapper>(s_Output);
@@ -289,9 +386,48 @@ struct SFixture {
289
386
290
387
template <typename RESULTS>
291
388
double readShapValue (const RESULTS& results, std::string shapField) {
292
- shapField = maths::CTreeShapFeatureImportance::SHAP_PREFIX + shapField;
293
- if (results[" row_results" ][" results" ][" ml" ].HasMember (shapField)) {
294
- return results[" row_results" ][" results" ][" ml" ][shapField].GetDouble ();
389
+ if (results[" row_results" ][" results" ][" ml" ].HasMember (
390
+ api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME)) {
391
+ for (const auto & shapResult :
392
+ results[" row_results" ][" results" ][" ml" ][api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME]
393
+ .GetArray ()) {
394
+ if (shapResult[api::CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME]
395
+ .GetString () == shapField) {
396
+ return shapResult[api::CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME]
397
+ .GetDouble ();
398
+ }
399
+ }
400
+ }
401
+ return 0.0 ;
402
+ }
403
+
404
+ template <typename RESULTS>
405
+ double readShapValue (const RESULTS& results, std::string shapField, std::string className) {
406
+ if (results[" row_results" ][" results" ][" ml" ].HasMember (
407
+ api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME)) {
408
+ for (const auto & shapResult :
409
+ results[" row_results" ][" results" ][" ml" ][api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME]
410
+ .GetArray ()) {
411
+ if (shapResult[api::CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME]
412
+ .GetString () == shapField) {
413
+ if (shapResult.HasMember (className)) {
414
+ return shapResult[className].GetDouble ();
415
+ }
416
+ }
417
+ }
418
+ }
419
+ return 0.0 ;
420
+ }
421
+
422
//! Read the predicted probability of class \p className from the "top_classes"
//! array of a single row result, or 0.0 if the class is not reported.
//! (Takes the class name by const reference to avoid a needless copy.)
template <typename RESULTS>
double readClassProbability(const RESULTS& results, const std::string& className) {
    const auto& ml = results["row_results"]["results"]["ml"];
    if (ml.HasMember("top_classes")) {
        // Each element describes one candidate class.
        for (const auto& topClass : ml["top_classes"].GetArray()) {
            if (topClass["class_name"].GetString() == className) {
                return topClass["class_probability"].GetDouble();
            }
        }
    }
    return 0.0;
}
@@ -324,9 +460,7 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) {
324
460
c3Sum += std::fabs (c3);
325
461
c4Sum += std::fabs (c4);
326
462
// assert that no SHAP value for the dependent variable is returned
327
- BOOST_TEST_REQUIRE (result[" row_results" ][" results" ][" ml" ].HasMember (
328
- maths::CTreeShapFeatureImportance::SHAP_PREFIX +
329
- " target" ) == false );
463
+ BOOST_REQUIRE_EQUAL (readShapValue (result, " target" ), 0.0 );
330
464
}
331
465
}
332
466
@@ -421,25 +555,58 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) {
421
555
BOOST_REQUIRE_SMALL (maths::CBasicStatistics::variance (bias), 1e-6 );
422
556
}
423
557
558
+ BOOST_FIXTURE_TEST_CASE (testMultiClassClassificationFeatureImportanceAllShap, SFixture) {
559
+
560
+ std::size_t topShapValues{4 };
561
+ auto results{runMultiClassClassification (topShapValues, {0.5 , -0.7 , 0.2 , -0.2 })};
562
+
563
+ for (const auto & result : results.GetArray ()) {
564
+ if (result.HasMember (" row_results" )) {
565
+ double c1Sum{readShapValue (result, " c1" )};
566
+ double c2Sum{readShapValue (result, " c2" )};
567
+ double c3Sum{readShapValue (result, " c3" )};
568
+ double c4Sum{readShapValue (result, " c4" )};
569
+ // We should have at least one feature that is important
570
+ BOOST_TEST_REQUIRE ((c1Sum > 0.0 || c2Sum > 0.0 || c3Sum > 0.0 || c4Sum > 0.0 ));
571
+
572
+ // class shap values should sum(abs()) to the overall feature importance
573
+ double c1f{readShapValue (result, " c1" , " foo" )};
574
+ double c1bar{readShapValue (result, " c1" , " bar" )};
575
+ double c1baz{readShapValue (result, " c1" , " baz" )};
576
+ BOOST_REQUIRE_CLOSE (
577
+ c1Sum, std::abs (c1f) + std::abs (c1bar) + std::abs (c1baz), 1e-6 );
578
+
579
+ double c2f{readShapValue (result, " c2" , " foo" )};
580
+ double c2bar{readShapValue (result, " c2" , " bar" )};
581
+ double c2baz{readShapValue (result, " c2" , " baz" )};
582
+ BOOST_REQUIRE_CLOSE (
583
+ c2Sum, std::abs (c2f) + std::abs (c2bar) + std::abs (c2baz), 1e-6 );
584
+
585
+ double c3f{readShapValue (result, " c3" , " foo" )};
586
+ double c3bar{readShapValue (result, " c3" , " bar" )};
587
+ double c3baz{readShapValue (result, " c3" , " baz" )};
588
+ BOOST_REQUIRE_CLOSE (
589
+ c3Sum, std::abs (c3f) + std::abs (c3bar) + std::abs (c3baz), 1e-6 );
590
+
591
+ double c4f{readShapValue (result, " c4" , " foo" )};
592
+ double c4bar{readShapValue (result, " c4" , " bar" )};
593
+ double c4baz{readShapValue (result, " c4" , " baz" )};
594
+ BOOST_REQUIRE_CLOSE (
595
+ c4Sum, std::abs (c4f) + std::abs (c4bar) + std::abs (c4baz), 1e-6 );
596
+ }
597
+ }
598
+ }
599
+
424
600
BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoShap, SFixture) {
    // Test that if topShapValue is set to 0, no feature importance values are returned.
    std::size_t topShapValues{0};
    auto results{runRegression(topShapValues, {50.0, 150.0, 50.0, -50.0})};

    for (const auto& result : results.GetArray()) {
        if (result.HasMember("row_results")) {
            const auto& ml = result["row_results"]["results"]["ml"];
            bool hasFeatureImportance{ml.HasMember(
                api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME)};
            BOOST_TEST_REQUIRE(hasFeatureImportance == false);
        }
    }
}
0 commit comments