Commit e764ffe

[ML] calculate feature importance for multi-class results (#1071) (#1075)
Feature importance is already calculated for multi-class models. This commit adjusts the output sent to ES so that multi-class importance can be explored. Feature importance objects are now mapped as follows.

(logistic) Regression:

```
{
    "feature_name": "feature_0",
    "importance": -1.3
}
```

Multi-class [class names are `foo`, `bar`, `baz`]:

```
{
    "feature_name": "feature_0",
    "importance": 2.0, // sum(abs()) of class importances
    "foo": 1.0,
    "bar": 0.5,
    "baz": -0.5
},
```

Java side change: elastic/elasticsearch#53803
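The multi-class `importance` is the L1 norm of the per-class values, i.e. `sum(abs())`. A minimal sketch of that arithmetic for the example above, using plain Eigen purely for illustration (not the project's own vector types):

```
#include <Eigen/Core>
#include <iostream>

int main() {
    // Per-class SHAP values for one feature, as in the multi-class example:
    // foo = 1.0, bar = 0.5, baz = -0.5.
    Eigen::Vector3d classImportances(1.0, 0.5, -0.5);

    // The "importance" field is the L1 norm (sum of absolute values).
    std::cout << classImportances.lpNorm<1>() << '\n'; // prints 2
}
```

In the classifier runner below this is exactly what `shap[i].lpNorm<1>()` computes.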
1 parent 8d86ac6 commit e764ffe

12 files changed: +237 -42 lines

docs/CHANGELOG.asciidoc

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ model training. (See {ml-pull}1034[#1034].)
 * Add instrumentation information for supervised learning data frame analytics jobs.
 (See {ml-pull}1031[#1031].)
 * Add instrumentation information for outlier detection data frame analytics jobs.
-(See {ml-pull}1068[#1068].)
+* Write out feature importance for multi-class models. (See {ml-pull}1071[#1071])
 
 === Bug Fixes

include/api/CDataFrameTrainBoostedTreeRunner.h

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,12 @@ class API_EXPORT CDataFrameTrainBoostedTreeRunner : public CDataFrameAnalysisRun
     static const std::string BAYESIAN_OPTIMISATION_RESTARTS;
     static const std::string NUM_TOP_FEATURE_IMPORTANCE_VALUES;
 
+    //Output
+    static const std::string IS_TRAINING_FIELD_NAME;
+    static const std::string FEATURE_NAME_FIELD_NAME;
+    static const std::string IMPORTANCE_FIELD_NAME;
+    static const std::string FEATURE_IMPORTANCE_FIELD_NAME;
+
 public:
     ~CDataFrameTrainBoostedTreeRunner() override;

include/maths/CTreeShapFeatureImportance.h

Lines changed: 0 additions & 3 deletions
@@ -43,9 +43,6 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
     using TShapWriter =
         std::function<void(const TSizeVec&, const TStrVec&, const TVectorVec&)>;
 
-public:
-    static const std::string SHAP_PREFIX;
-
 public:
     CTreeShapFeatureImportance(const core::CDataFrame& frame,
                                const CDataFrameCategoryEncoder& encoder,

include/test/CDataFrameAnalysisSpecificationFactory.h

Lines changed: 2 additions & 0 deletions
@@ -75,6 +75,7 @@ class TEST_EXPORT CDataFrameAnalysisSpecificationFactory {
 
     // Classification
     CDataFrameAnalysisSpecificationFactory& numberClasses(std::size_t number);
+    CDataFrameAnalysisSpecificationFactory& numberTopClasses(std::size_t number);
    CDataFrameAnalysisSpecificationFactory& predictionFieldType(const std::string& type);
 
     std::string outlierParams() const;
@@ -117,6 +118,7 @@ class TEST_EXPORT CDataFrameAnalysisSpecificationFactory {
     TRestoreSearcherSupplier* m_RestoreSearcherSupplier = nullptr;
     // Classification
     std::size_t m_NumberClasses = 2;
+    std::size_t m_NumberTopClasses = 0;
     std::string m_PredictionFieldType;
 };
 }

lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc

Lines changed: 22 additions & 6 deletions
@@ -163,16 +163,32 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
 
     if (featureImportance != nullptr) {
         featureImportance->shap(
-            row, [&writer](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
-                           const TStrVec& names,
-                           const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
+            row, [&writer, &classValues](
+                     const maths::CTreeShapFeatureImportance::TSizeVec& indices,
+                     const TStrVec& names,
+                     const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
+                writer.Key(CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME);
+                writer.StartArray();
                 for (auto i : indices) {
                     if (shap[i].norm() != 0.0) {
-                        writer.Key(names[i]);
-                        // TODO fixme
-                        writer.Double(shap[i](0));
+                        writer.StartObject();
+                        writer.Key(CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME);
+                        writer.String(names[i]);
+                        if (shap[i].size() == 1) {
+                            writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME);
+                            writer.Double(shap[i](0));
+                        } else {
+                            for (int j = 0; j < shap[i].size(); ++j) {
+                                writer.Key(classValues[j]);
+                                writer.Double(shap[i](j));
+                            }
+                            writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME);
+                            writer.Double(shap[i].lpNorm<1>());
+                        }
+                        writer.EndObject();
                     }
                 }
+                writer.EndArray();
             });
     }
     writer.EndObject();
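For orientation, here is a rough standalone sketch of the JSON shape this writer nesting produces for one multi-class feature, using rapidjson directly; the literal field names and values stand in for the runner's constants and real SHAP output:

```
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>

#include <cmath>
#include <iostream>
#include <string>
#include <vector>

int main() {
    rapidjson::StringBuffer buffer;
    rapidjson::Writer<rapidjson::StringBuffer> writer{buffer};

    std::vector<std::string> classValues{"foo", "bar", "baz"};
    std::vector<double> shap{1.0, 0.5, -0.5}; // per-class SHAP values of one feature

    writer.StartObject();
    writer.Key("feature_importance"); // FEATURE_IMPORTANCE_FIELD_NAME
    writer.StartArray();
    writer.StartObject();
    writer.Key("feature_name"); // FEATURE_NAME_FIELD_NAME
    writer.String("feature_0");
    double l1{0.0};
    for (std::size_t j = 0; j < shap.size(); ++j) {
        writer.Key(classValues[j].c_str()); // one field per class
        writer.Double(shap[j]);
        l1 += std::fabs(shap[j]);
    }
    writer.Key("importance"); // IMPORTANCE_FIELD_NAME: sum(abs()) over classes
    writer.Double(l1);
    writer.EndObject();
    writer.EndArray();
    writer.EndObject();

    // Roughly: {"feature_importance":[{"feature_name":"feature_0",
    //           "foo":1.0,"bar":0.5,"baz":-0.5,"importance":2.0}]}
    std::cout << buffer.GetString() << '\n';
}
```

In the runner itself this array is nested under the per-row `ml` results object, which is how the tests below read it back.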

lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc

Lines changed: 8 additions & 1 deletion
@@ -82,12 +82,19 @@ void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow(
             row, [&writer](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
                            const TStrVec& names,
                            const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
+                writer.Key(CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME);
+                writer.StartArray();
                 for (auto i : indices) {
                     if (shap[i].norm() != 0.0) {
-                        writer.Key(names[i]);
+                        writer.StartObject();
+                        writer.Key(CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME);
+                        writer.String(names[i]);
+                        writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME);
                         writer.Double(shap[i](0));
+                        writer.EndObject();
                     }
                 }
+                writer.EndArray();
             });
     }
     writer.EndObject();

lib/api/CDataFrameTrainBoostedTreeRunner.cc

Lines changed: 4 additions & 0 deletions
@@ -311,6 +311,10 @@ const std::string CDataFrameTrainBoostedTreeRunner::STOP_CROSS_VALIDATION_EARLY{
 const std::string CDataFrameTrainBoostedTreeRunner::NUMBER_ROUNDS_PER_HYPERPARAMETER{"number_rounds_per_hyperparameter"};
 const std::string CDataFrameTrainBoostedTreeRunner::BAYESIAN_OPTIMISATION_RESTARTS{"bayesian_optimisation_restarts"};
 const std::string CDataFrameTrainBoostedTreeRunner::NUM_TOP_FEATURE_IMPORTANCE_VALUES{"num_top_feature_importance_values"};
+const std::string CDataFrameTrainBoostedTreeRunner::IS_TRAINING_FIELD_NAME{"is_training"};
+const std::string CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME{"feature_name"};
+const std::string CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME{"importance"};
+const std::string CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME{"feature_importance"};
 // clang-format on
 }
 }

lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc

Lines changed: 185 additions & 18 deletions
@@ -8,10 +8,13 @@
 
 #include <maths/CBasicStatistics.h>
 #include <maths/CDataFramePredictiveModel.h>
+#include <maths/CSampling.h>
 #include <maths/CTools.h>
+#include <maths/CToolsDetail.h>
 #include <maths/CTreeShapFeatureImportance.h>
 
 #include <api/CDataFrameAnalyzer.h>
+#include <api/CDataFrameTrainBoostedTreeRunner.h>
 
 #include <test/CDataFrameAnalysisSpecificationFactory.h>
 #include <test/CRandomNumbers.h>
@@ -27,12 +30,14 @@ using namespace ml;
 
 namespace {
 using TDoubleVec = std::vector<double>;
+using TVector = maths::CDenseVector<double>;
 using TStrVec = std::vector<std::string>;
 using TRowItr = core::CDataFrame::TRowItr;
 using TRowRef = core::CDataFrame::TRowRef;
 using TMeanAccumulator = maths::CBasicStatistics::SSampleMean<double>::TAccumulator;
 using TMeanAccumulatorVec = std::vector<TMeanAccumulator>;
 using TMeanVarAccumulator = maths::CBasicStatistics::SSampleMeanVar<double>::TAccumulator;
+using TMemoryMappedMatrix = maths::CMemoryMappedDenseMatrix<double>;
 
 void setupLinearRegressionData(const TStrVec& fieldNames,
                                TStrVec& fieldValues,
@@ -128,6 +133,47 @@ void setupBinaryClassificationData(const TStrVec& fieldNames,
     }
 }
 
+void setupMultiClassClassificationData(const TStrVec& fieldNames,
+                                       TStrVec& fieldValues,
+                                       api::CDataFrameAnalyzer& analyzer,
+                                       const TDoubleVec& weights,
+                                       const TDoubleVec& values) {
+    TStrVec classes{"foo", "bar", "baz"};
+    maths::CPRNG::CXorOShiro128Plus rng;
+    std::uniform_real_distribution<double> u01;
+    int numberFeatures{static_cast<int>(weights.size())};
+    TDoubleVec w{weights};
+    int numberClasses{static_cast<int>(classes.size())};
+    auto probability = [&](const TDoubleVec& row) {
+        TMemoryMappedMatrix W(&w[0], numberClasses, numberFeatures);
+        TVector x(numberFeatures);
+        for (int i = 0; i < numberFeatures; ++i) {
+            x(i) = row[i];
+        }
+        TVector logit{W * x};
+        return maths::CTools::softmax(std::move(logit));
+    };
+    auto target = [&](const TDoubleVec& row) {
+        TDoubleVec probabilities{probability(row).to<TDoubleVec>()};
+        return classes[maths::CSampling::categoricalSample(rng, probabilities)];
+    };
+
+    for (std::size_t i = 0; i < values.size(); i += weights.size()) {
+        TDoubleVec row(weights.size());
+        for (std::size_t j = 0; j < weights.size(); ++j) {
+            row[j] = values[i + j];
+        }
+
+        fieldValues[0] = target(row);
+        for (std::size_t j = 0; j < row.size(); ++j) {
+            fieldValues[j + 1] = core::CStringUtils::typeToStringPrecise(
+                row[j], core::CIEEE754::E_DoublePrecision);
+        }
+
+        analyzer.handleRecord(fieldNames, fieldValues);
+    }
+}
+
 struct SFixture {
     rapidjson::Document
     runRegression(std::size_t shapValues, TDoubleVec weights, double noiseVar = 0.0) {
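An aside on `setupMultiClassClassificationData` above: each label is drawn from a categorical distribution whose probabilities are `softmax(W * x)`. A minimal standard-library sketch of the same idea (illustrative names only, not the `maths::` APIs used in the test):

```
#include <algorithm>
#include <cmath>
#include <numeric>
#include <random>
#include <vector>

// Softmax of a logit vector followed by a categorical draw, mirroring the
// combination of maths::CTools::softmax and maths::CSampling::categoricalSample.
std::size_t sampleClass(const std::vector<double>& logits, std::mt19937& rng) {
    // Shift by the max logit for numerical stability before exponentiating.
    double maxLogit{*std::max_element(logits.begin(), logits.end())};
    std::vector<double> probabilities(logits.size());
    std::transform(logits.begin(), logits.end(), probabilities.begin(),
                   [maxLogit](double z) { return std::exp(z - maxLogit); });
    double norm{std::accumulate(probabilities.begin(), probabilities.end(), 0.0)};
    for (auto& p : probabilities) {
        p /= norm;
    }
    // discrete_distribution draws index j with probability proportional to p[j].
    std::discrete_distribution<std::size_t> categorical(probabilities.begin(),
                                                        probabilities.end());
    return categorical(rng);
}
```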
@@ -231,6 +277,57 @@ struct SFixture {
         return results;
     }
 
+    rapidjson::Document runMultiClassClassification(std::size_t shapValues,
+                                                    TDoubleVec&& weights) {
+        auto outputWriterFactory = [&]() {
+            return std::make_unique<core::CJsonOutputStreamWrapper>(s_Output);
+        };
+        test::CDataFrameAnalysisSpecificationFactory specFactory;
+        api::CDataFrameAnalyzer analyzer{
+            specFactory.rows(s_Rows)
+                .memoryLimit(26000000)
+                .predictionCategoricalFieldNames({"target"})
+                .predictionAlpha(s_Alpha)
+                .predictionLambda(s_Lambda)
+                .predictionGamma(s_Gamma)
+                .predictionSoftTreeDepthLimit(s_SoftTreeDepthLimit)
+                .predictionSoftTreeDepthTolerance(s_SoftTreeDepthTolerance)
+                .predictionEta(s_Eta)
+                .predictionMaximumNumberTrees(s_MaximumNumberTrees)
+                .predictionFeatureBagFraction(s_FeatureBagFraction)
+                .predictionNumberTopShapValues(shapValues)
+                .numberClasses(3)
+                .numberTopClasses(3)
+                .predictionSpec(test::CDataFrameAnalysisSpecificationFactory::classification(), "target"),
+            outputWriterFactory};
+        TStrVec fieldNames{"target", "c1", "c2", "c3", "c4", ".", "."};
+        TStrVec fieldValues{"", "", "", "", "", "0", ""};
+        test::CRandomNumbers rng;
+
+        TDoubleVec values;
+        rng.generateUniformSamples(-10.0, 10.0, weights.size() * s_Rows, values);
+
+        setupMultiClassClassificationData(fieldNames, fieldValues, analyzer, weights, values);
+
+        analyzer.handleRecord(fieldNames, {"", "", "", "", "", "", "$"});
+
+        LOG_DEBUG(<< "estimated memory usage = "
+                  << core::CProgramCounters::counter(counter_t::E_DFTPMEstimatedPeakMemoryUsage));
+        LOG_DEBUG(<< "peak memory = "
+                  << core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage));
+        LOG_DEBUG(<< "time to train = " << core::CProgramCounters::counter(counter_t::E_DFTPMTimeToTrain)
+                  << "ms");
+
+        BOOST_TEST_REQUIRE(
+            core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage) <
+            core::CProgramCounters::counter(counter_t::E_DFTPMEstimatedPeakMemoryUsage));
+
+        rapidjson::Document results;
+        rapidjson::ParseResult ok(results.Parse(s_Output.str()));
+        BOOST_TEST_REQUIRE(static_cast<bool>(ok) == true);
+        return results;
+    }
+
     rapidjson::Document runRegressionWithMissingFeatures(std::size_t shapValues) {
         auto outputWriterFactory = [&]() {
             return std::make_unique<core::CJsonOutputStreamWrapper>(s_Output);
@@ -289,9 +386,48 @@ struct SFixture {
 
     template<typename RESULTS>
     double readShapValue(const RESULTS& results, std::string shapField) {
-        shapField = maths::CTreeShapFeatureImportance::SHAP_PREFIX + shapField;
-        if (results["row_results"]["results"]["ml"].HasMember(shapField)) {
-            return results["row_results"]["results"]["ml"][shapField].GetDouble();
+        if (results["row_results"]["results"]["ml"].HasMember(
+                api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME)) {
+            for (const auto& shapResult :
+                 results["row_results"]["results"]["ml"][api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME]
+                     .GetArray()) {
+                if (shapResult[api::CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME]
+                        .GetString() == shapField) {
+                    return shapResult[api::CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME]
+                        .GetDouble();
+                }
+            }
+        }
+        return 0.0;
+    }
+
+    template<typename RESULTS>
+    double readShapValue(const RESULTS& results, std::string shapField, std::string className) {
+        if (results["row_results"]["results"]["ml"].HasMember(
+                api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME)) {
+            for (const auto& shapResult :
+                 results["row_results"]["results"]["ml"][api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME]
+                     .GetArray()) {
+                if (shapResult[api::CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME]
+                        .GetString() == shapField) {
+                    if (shapResult.HasMember(className)) {
+                        return shapResult[className].GetDouble();
+                    }
+                }
+            }
+        }
+        return 0.0;
+    }
+
+    template<typename RESULTS>
+    double readClassProbability(const RESULTS& results, std::string className) {
+        if (results["row_results"]["results"]["ml"].HasMember("top_classes")) {
+            for (const auto& topClasses :
+                 results["row_results"]["results"]["ml"]["top_classes"].GetArray()) {
+                if (topClasses["class_name"].GetString() == className) {
+                    return topClasses["class_probability"].GetDouble();
+                }
+            }
         }
         return 0.0;
     }
@@ -324,9 +460,7 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) {
             c3Sum += std::fabs(c3);
             c4Sum += std::fabs(c4);
             // assert that no SHAP value for the dependent variable is returned
-            BOOST_TEST_REQUIRE(result["row_results"]["results"]["ml"].HasMember(
-                                   maths::CTreeShapFeatureImportance::SHAP_PREFIX +
-                                   "target") == false);
+            BOOST_REQUIRE_EQUAL(readShapValue(result, "target"), 0.0);
         }
     }
 
@@ -421,25 +555,58 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) {
     BOOST_REQUIRE_SMALL(maths::CBasicStatistics::variance(bias), 1e-6);
 }
 
+BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SFixture) {
+
+    std::size_t topShapValues{4};
+    auto results{runMultiClassClassification(topShapValues, {0.5, -0.7, 0.2, -0.2})};
+
+    for (const auto& result : results.GetArray()) {
+        if (result.HasMember("row_results")) {
+            double c1Sum{readShapValue(result, "c1")};
+            double c2Sum{readShapValue(result, "c2")};
+            double c3Sum{readShapValue(result, "c3")};
+            double c4Sum{readShapValue(result, "c4")};
+            // We should have at least one feature that is important
+            BOOST_TEST_REQUIRE((c1Sum > 0.0 || c2Sum > 0.0 || c3Sum > 0.0 || c4Sum > 0.0));
+
+            // class shap values should sum(abs()) to the overall feature importance
+            double c1f{readShapValue(result, "c1", "foo")};
+            double c1bar{readShapValue(result, "c1", "bar")};
+            double c1baz{readShapValue(result, "c1", "baz")};
+            BOOST_REQUIRE_CLOSE(
+                c1Sum, std::abs(c1f) + std::abs(c1bar) + std::abs(c1baz), 1e-6);
+
+            double c2f{readShapValue(result, "c2", "foo")};
+            double c2bar{readShapValue(result, "c2", "bar")};
+            double c2baz{readShapValue(result, "c2", "baz")};
+            BOOST_REQUIRE_CLOSE(
+                c2Sum, std::abs(c2f) + std::abs(c2bar) + std::abs(c2baz), 1e-6);
+
+            double c3f{readShapValue(result, "c3", "foo")};
+            double c3bar{readShapValue(result, "c3", "bar")};
+            double c3baz{readShapValue(result, "c3", "baz")};
+            BOOST_REQUIRE_CLOSE(
+                c3Sum, std::abs(c3f) + std::abs(c3bar) + std::abs(c3baz), 1e-6);
+
+            double c4f{readShapValue(result, "c4", "foo")};
+            double c4bar{readShapValue(result, "c4", "bar")};
+            double c4baz{readShapValue(result, "c4", "baz")};
+            BOOST_REQUIRE_CLOSE(
+                c4Sum, std::abs(c4f) + std::abs(c4bar) + std::abs(c4baz), 1e-6);
+        }
+    }
+}
+
 BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoShap, SFixture) {
     // Test that if topShapValue is set to 0, no feature importance values are returned.
     std::size_t topShapValues{0};
     auto results{runRegression(topShapValues, {50.0, 150.0, 50.0, -50.0})};
 
     for (const auto& result : results.GetArray()) {
         if (result.HasMember("row_results")) {
-            BOOST_TEST_REQUIRE(
-                result["row_results"]["results"]["ml"].HasMember(
-                    maths::CTreeShapFeatureImportance::SHAP_PREFIX + "c1") == false);
-            BOOST_TEST_REQUIRE(
-                result["row_results"]["results"]["ml"].HasMember(
-                    maths::CTreeShapFeatureImportance::SHAP_PREFIX + "c2") == false);
-            BOOST_TEST_REQUIRE(
-                result["row_results"]["results"]["ml"].HasMember(
-                    maths::CTreeShapFeatureImportance::SHAP_PREFIX + "c3") == false);
-            BOOST_TEST_REQUIRE(
-                result["row_results"]["results"]["ml"].HasMember(
-                    maths::CTreeShapFeatureImportance::SHAP_PREFIX + "c4") == false);
+            BOOST_TEST_REQUIRE(result["row_results"]["results"]["ml"].HasMember(
+                                   api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME) ==
+                               false);
         }
     }
 }

lib/api/unittest/CDataFrameAnalyzerTrainingTest.cc

Lines changed: 1 addition & 0 deletions
@@ -587,6 +587,7 @@ BOOST_AUTO_TEST_CASE(testRunBoostedTreeClassifierTraining) {
     api::CDataFrameAnalyzer analyzer{
         specFactory.memoryLimit(6000000)
             .predictionCategoricalFieldNames({"target"})
+            .numberTopClasses(1)
             .predictionSpec(test::CDataFrameAnalysisSpecificationFactory::classification(), "target"),
         outputWriterFactory};
     test::CDataFrameAnalyzerTrainingFactory::addPredictionTestData(
