Skip to content

Commit e169729

Browse files
authored
[7.x][ML] Return total SHAP per feature as a new result type (#1455)
This PR adds computation of the total feature importance values. Backport of #1387.
1 parent 1704ece commit e169729

16 files changed

+517
-35
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
regression. (See {ml-pull}1340[#1340].)
6565
* Improvement in handling large inference model definitions. (See {ml-pull}1349[#1349].)
6666
* Add a peak_model_bytes field to model_size_stats. (See {ml-pull}1389[#1389].)
67+
* Calculate total feature importance as a new result type. (See {ml-pull}1387[#1387].)
6768

6869
=== Bug Fixes
6970

include/api/CDataFrameAnalysisRunner.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212

1313
#include <api/CDataFrameAnalysisInstrumentation.h>
1414
#include <api/CInferenceModelDefinition.h>
15+
#include <api/CInferenceModelMetadata.h>
1516
#include <api/ImportExport.h>
1617

1718
#include <rapidjson/fwd.h>
1819

20+
#include <boost/optional.hpp>
21+
1922
#include <cstddef>
2023
#include <functional>
2124
#include <memory>
@@ -66,6 +69,7 @@ class API_EXPORT CDataFrameAnalysisRunner {
6669
using TProgressRecorder = std::function<void(double)>;
6770
using TStrVecVec = std::vector<TStrVec>;
6871
using TInferenceModelDefinitionUPtr = std::unique_ptr<CInferenceModelDefinition>;
72+
using TOptionalInferenceModelMetadata = boost::optional<const CInferenceModelMetadata&>;
6973

7074
public:
7175
//! The intention is that concrete objects of this hierarchy are constructed
@@ -141,6 +145,9 @@ class API_EXPORT CDataFrameAnalysisRunner {
141145
virtual TInferenceModelDefinitionUPtr
142146
inferenceModelDefinition(const TStrVec& fieldNames, const TStrVecVec& categoryNames) const;
143147

148+
//! \return The serialisable metadata of the trained model, if available.
149+
virtual TOptionalInferenceModelMetadata inferenceModelMetadata() const;
150+
144151
//! \return Reference to the analysis instrumentation.
145152
virtual const CDataFrameAnalysisInstrumentation& instrumentation() const = 0;
146153
//! \return Reference to the analysis instrumentation.

include/api/CDataFrameAnalyzer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ class API_EXPORT CDataFrameAnalyzer {
8787
core::CRapidJsonConcurrentLineWriter& writer) const;
8888
void writeInferenceModel(const CDataFrameAnalysisRunner& analysis,
8989
core::CRapidJsonConcurrentLineWriter& writer) const;
90+
void writeInferenceModelMetadata(const CDataFrameAnalysisRunner& analysis,
91+
core::CRapidJsonConcurrentLineWriter& writer) const;
9092

9193
private:
9294
// This has values: -2 (unset), -1 (missing), >= 0 (control field index).

include/api/CDataFrameTrainBoostedTreeClassifierRunner.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <core/CSmallVector.h>
1111

1212
#include <api/CDataFrameTrainBoostedTreeRunner.h>
13+
#include <api/CInferenceModelMetadata.h>
1314
#include <api/ImportExport.h>
1415

1516
#include <rapidjson/fwd.h>
@@ -40,6 +41,8 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
4041
static const std::string NUM_TOP_CLASSES;
4142
static const std::string PREDICTION_FIELD_TYPE;
4243
static const std::string CLASS_ASSIGNMENT_OBJECTIVE;
44+
static const std::string CLASSES_FIELD_NAME;
45+
static const std::string CLASS_NAME_FIELD_NAME;
4346
static const TStrVec CLASS_ASSIGNMENT_OBJECTIVE_VALUES;
4447

4548
public:
@@ -70,6 +73,9 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
7073
inferenceModelDefinition(const TStrVec& fieldNames,
7174
const TStrVecVec& categoryNames) const override;
7275

76+
//! \return The serialisable metadata of the trained classification model.
77+
TOptionalInferenceModelMetadata inferenceModelMetadata() const override;
78+
7379
private:
7480
static TLossFunctionUPtr loss(std::size_t numberClasses);
7581

@@ -82,6 +88,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
8288
private:
8389
std::size_t m_NumTopClasses;
8490
EPredictionFieldType m_PredictionFieldType;
91+
mutable CInferenceModelMetadata m_InferenceModelMetadata;
8592
};
8693

8794
//! \brief Makes a core::CDataFrame boosted tree classification runner.

include/api/CDataFrameTrainBoostedTreeRegressionRunner.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <maths/CBoostedTreeLoss.h>
1111

1212
#include <api/CDataFrameTrainBoostedTreeRunner.h>
13+
#include <api/CInferenceModelMetadata.h>
1314
#include <api/ImportExport.h>
1415

1516
#include <rapidjson/fwd.h>
@@ -51,10 +52,15 @@ class API_EXPORT CDataFrameTrainBoostedTreeRegressionRunner final
5152
TInferenceModelDefinitionUPtr
5253
inferenceModelDefinition(const TStrVec& fieldNames,
5354
const TStrVecVec& categoryNameMap) const override;
55+
//! \return The serialisable metadata of the trained regression model.
56+
TOptionalInferenceModelMetadata inferenceModelMetadata() const override;
5457

5558
private:
5659
void validate(const core::CDataFrame& frame,
5760
std::size_t dependentVariableColumn) const override;
61+
62+
private:
63+
mutable CInferenceModelMetadata m_InferenceModelMetadata;
5864
};
5965

6066
//! \brief Makes a core::CDataFrame boosted tree regression runner.

include/api/CInferenceModelMetadata.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
#ifndef INCLUDED_ml_api_CInferenceModelMetadata_h
7+
#define INCLUDED_ml_api_CInferenceModelMetadata_h
8+
9+
#include <maths/CBasicStatistics.h>
10+
#include <maths/CLinearAlgebraEigen.h>
11+
12+
#include <api/CInferenceModelDefinition.h>
13+
#include <api/ImportExport.h>
14+
15+
#include <string>
16+
17+
namespace ml {
18+
namespace api {
19+
20+
//! \brief Class controls the serialization of the model meta information
21+
//! (such as total feature importance) into JSON format.
22+
class API_EXPORT CInferenceModelMetadata {
23+
public:
24+
static const std::string JSON_CLASS_NAME_TAG;
25+
static const std::string JSON_CLASSES_TAG;
26+
static const std::string JSON_FEATURE_NAME_TAG;
27+
static const std::string JSON_IMPORTANCE_TAG;
28+
static const std::string JSON_MAX_TAG;
29+
static const std::string JSON_MEAN_MAGNITUDE_TAG;
30+
static const std::string JSON_MIN_TAG;
31+
static const std::string JSON_MODEL_METADATA_TAG;
32+
static const std::string JSON_TOTAL_FEATURE_IMPORTANCE_TAG;
33+
34+
public:
35+
using TVector = maths::CDenseVector<double>;
36+
using TStrVec = std::vector<std::string>;
37+
using TRapidJsonWriter = core::CRapidJsonConcurrentLineWriter;
38+
39+
public:
40+
//! Writes metadata using \p writer.
41+
void write(TRapidJsonWriter& writer) const;
42+
void columnNames(const TStrVec& columnNames);
43+
void classValues(const TStrVec& classValues);
44+
const std::string& typeString() const;
45+
//! Add importances \p values to the feature with index \p i to calculate total feature importance.
46+
//! Total feature importance is the mean of the magnitudes of importances for individual data points.
47+
void addToFeatureImportance(std::size_t i, const TVector& values);
48+
49+
private:
50+
using TMeanVarAccumulator = maths::CBasicStatistics::SSampleMeanVar<TVector>::TAccumulator;
51+
using TMinMaxAccumulator = std::vector<maths::CBasicStatistics::CMinMax<double>>;
52+
using TSizeMeanVarAccumulatorUMap = std::unordered_map<std::size_t, TMeanVarAccumulator>;
53+
using TSizeMinMaxAccumulatorUMap = std::unordered_map<std::size_t, TMinMaxAccumulator>;
54+
55+
private:
56+
void writeTotalFeatureImportance(TRapidJsonWriter& writer) const;
57+
58+
private:
59+
TSizeMeanVarAccumulatorUMap m_TotalShapValuesMeanVar;
60+
TSizeMinMaxAccumulatorUMap m_TotalShapValuesMinMax;
61+
TStrVec m_ColumnNames;
62+
TStrVec m_ClassValues;
63+
};
64+
}
65+
}
66+
67+
#endif //INCLUDED_ml_api_CInferenceModelMetadata_h

include/maths/CBasicStatistics.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ class MATHS_EXPORT CBasicStatistics {
245245

246246
if (ORDER > 1) {
247247
T r{x - s_Moments[0]};
248-
T r2{r * r};
248+
T r2{las::componentwise(r) * las::componentwise(r)};
249249
T dMean{mean - s_Moments[0]};
250250
T dMean2{las::componentwise(dMean) * las::componentwise(dMean)};
251251
T variance{s_Moments[1]};

include/maths/CTreeShapFeatureImportance.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
7373
//! Get the maximum depth of any tree in \p forest.
7474
static std::size_t depth(const TTreeVec& forest);
7575

76+
//! Get the column names.
77+
const TStrVec& columnNames() const;
78+
7679
private:
7780
//! Collects the elements of the path through decision tree that are updated together
7881
struct SPathElement {

lib/api/CDataFrameAnalysisRunner.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ CDataFrameAnalysisRunner::inferenceModelDefinition(const TStrVec& /*fieldNames*/
193193
return TInferenceModelDefinitionUPtr();
194194
}
195195

196+
CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
197+
CDataFrameAnalysisRunner::inferenceModelMetadata() const {
198+
return TOptionalInferenceModelMetadata();
199+
}
200+
196201
CDataFrameAnalysisRunnerFactory::TRunnerUPtr
197202
CDataFrameAnalysisRunnerFactory::make(const CDataFrameAnalysisSpecification& spec) const {
198203
auto result = this->makeImpl(spec);

lib/api/CDataFrameAnalyzer.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ void CDataFrameAnalyzer::run() {
144144
analysisRunner->waitToFinish();
145145
this->writeInferenceModel(*analysisRunner, outputWriter);
146146
this->writeResultsOf(*analysisRunner, outputWriter);
147+
// TODO reactivate once Java parsing is ready
148+
// this->writeInferenceModelMetadata(*analysisRunner, outputWriter);
147149
}
148150
}
149151

@@ -286,6 +288,21 @@ void CDataFrameAnalyzer::writeInferenceModel(const CDataFrameAnalysisRunner& ana
286288
writer.flush();
287289
}
288290

291+
void CDataFrameAnalyzer::writeInferenceModelMetadata(const CDataFrameAnalysisRunner& analysis,
292+
core::CRapidJsonConcurrentLineWriter& writer) const {
293+
// Write model meta information
294+
auto modelMetadata = analysis.inferenceModelMetadata();
295+
if (modelMetadata) {
296+
writer.StartObject();
297+
writer.Key(modelMetadata->typeString());
298+
writer.StartObject();
299+
modelMetadata->write(writer);
300+
writer.EndObject();
301+
writer.EndObject();
302+
}
303+
writer.flush();
304+
}
305+
289306
void CDataFrameAnalyzer::writeResultsOf(const CDataFrameAnalysisRunner& analysis,
290307
core::CRapidJsonConcurrentLineWriter& writer) const {
291308

lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <maths/CBoostedTreeLoss.h>
1616
#include <maths/CDataFramePredictiveModel.h>
1717
#include <maths/CDataFrameUtils.h>
18+
#include <maths/CLinearAlgebraEigen.h>
1819
#include <maths/COrderings.h>
1920
#include <maths/CTools.h>
2021
#include <maths/CTreeShapFeatureImportance.h>
@@ -41,7 +42,6 @@ const std::string IS_TRAINING_FIELD_NAME{"is_training"};
4142
const std::string PREDICTION_PROBABILITY_FIELD_NAME{"prediction_probability"};
4243
const std::string PREDICTION_SCORE_FIELD_NAME{"prediction_score"};
4344
const std::string TOP_CLASSES_FIELD_NAME{"top_classes"};
44-
const std::string CLASS_NAME_FIELD_NAME{"class_name"};
4545
const std::string CLASS_PROBABILITY_FIELD_NAME{"class_probability"};
4646
const std::string CLASS_SCORE_FIELD_NAME{"class_score"};
4747

@@ -162,7 +162,9 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
162162
}
163163

164164
if (featureImportance != nullptr) {
165-
int numberClasses{static_cast<int>(classValues.size())};
165+
std::size_t numberClasses{classValues.size()};
166+
m_InferenceModelMetadata.columnNames(featureImportance->columnNames());
167+
m_InferenceModelMetadata.classValues(classValues);
166168
featureImportance->shap(
167169
row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
168170
const TStrVec& featureNames,
@@ -175,20 +177,47 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
175177
writer.Key(FEATURE_NAME_FIELD_NAME);
176178
writer.String(featureNames[i]);
177179
if (shap[i].size() == 1) {
178-
writer.Key(IMPORTANCE_FIELD_NAME);
179-
writer.Double(shap[i](0));
180+
// output feature importance for individual classes in binary case
181+
writer.Key(CLASSES_FIELD_NAME);
182+
writer.StartArray();
183+
for (std::size_t j = 0; j < numberClasses; ++j) {
184+
double importance{(j == predictedClassId)
185+
? shap[i](0)
186+
: -shap[i](0)};
187+
writer.StartObject();
188+
writer.Key(CLASS_NAME_FIELD_NAME);
189+
writer.String(classValues[j]);
190+
writer.Key(IMPORTANCE_FIELD_NAME);
191+
writer.Double(importance);
192+
writer.EndObject();
193+
}
194+
writer.EndArray();
180195
} else {
181-
for (int j = 0; j < shap[i].size() && j < numberClasses; ++j) {
182-
writer.Key(classValues[j]);
196+
// output feature importance for individual classes in multiclass case
197+
writer.Key(CLASSES_FIELD_NAME);
198+
writer.StartArray();
199+
for (std::size_t j = 0;
200+
j < shap[i].size() && j < numberClasses; ++j) {
201+
writer.StartObject();
202+
writer.Key(CLASS_NAME_FIELD_NAME);
203+
writer.String(classValues[j]);
204+
writer.Key(IMPORTANCE_FIELD_NAME);
183205
writer.Double(shap[i](j));
206+
writer.EndObject();
184207
}
185-
writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME);
186-
writer.Double(shap[i].lpNorm<1>());
208+
writer.EndArray();
187209
}
188210
writer.EndObject();
189211
}
190212
}
191213
writer.EndArray();
214+
215+
for (std::size_t i = 0; i < shap.size(); ++i) {
216+
if (shap[i].lpNorm<1>() != 0) {
217+
const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this)
218+
->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
219+
}
220+
}
192221
});
193222
}
194223
writer.EndObject();
@@ -257,6 +286,11 @@ CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelDefinition(
257286
return std::make_unique<CInferenceModelDefinition>(builder.build());
258287
}
259288

289+
CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
290+
CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelMetadata() const {
291+
return m_InferenceModelMetadata;
292+
}
293+
260294
// clang-format off
261295
// The MAX_NUMBER_CLASSES must match the value used in the Java code. See the
262296
// MAX_DEPENDENT_VARIABLE_CARDINALITY in the x-pack classification code.
@@ -291,5 +325,7 @@ CDataFrameTrainBoostedTreeClassifierRunnerFactory::makeImpl(
291325
}
292326

293327
const std::string CDataFrameTrainBoostedTreeClassifierRunnerFactory::NAME{"classification"};
328+
const std::string CDataFrameTrainBoostedTreeClassifierRunner::CLASSES_FIELD_NAME{"classes"};
329+
const std::string CDataFrameTrainBoostedTreeClassifierRunner::CLASS_NAME_FIELD_NAME{"class_name"};
294330
}
295331
}

lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,11 @@ void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow(
109109
writer.Bool(maths::CDataFrameUtils::isMissing(row[columnHoldingDependentVariable]) == false);
110110
auto featureImportance = tree.shap();
111111
if (featureImportance != nullptr) {
112+
m_InferenceModelMetadata.columnNames(featureImportance->columnNames());
112113
featureImportance->shap(
113-
row, [&writer](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
114-
const TStrVec& featureNames,
115-
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
114+
row, [&writer, this](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
115+
const TStrVec& featureNames,
116+
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
116117
writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
117118
writer.StartArray();
118119
for (auto i : indices) {
@@ -126,6 +127,13 @@ void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow(
126127
}
127128
}
128129
writer.EndArray();
130+
131+
for (int i = 0; i < shap.size(); ++i) {
132+
if (shap[i].lpNorm<1>() != 0) {
133+
const_cast<CDataFrameTrainBoostedTreeRegressionRunner*>(this)
134+
->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
135+
}
136+
}
129137
});
130138
}
131139
writer.EndObject();
@@ -145,6 +153,11 @@ CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelDefinition(
145153
return std::make_unique<CInferenceModelDefinition>(builder.build());
146154
}
147155

156+
CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
157+
CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelMetadata() const {
158+
return TOptionalInferenceModelMetadata(m_InferenceModelMetadata);
159+
}
160+
148161
// clang-format off
149162
const std::string CDataFrameTrainBoostedTreeRegressionRunner::STRATIFIED_CROSS_VALIDATION{"stratified_cross_validation"};
150163
const std::string CDataFrameTrainBoostedTreeRegressionRunner::LOSS_FUNCTION{"loss_function"};
@@ -160,7 +173,7 @@ const std::string& CDataFrameTrainBoostedTreeRegressionRunnerFactory::name() con
160173

161174
CDataFrameTrainBoostedTreeRegressionRunnerFactory::TRunnerUPtr
162175
CDataFrameTrainBoostedTreeRegressionRunnerFactory::makeImpl(const CDataFrameAnalysisSpecification&) const {
163-
HANDLE_FATAL(<< "Input error: classification has a non-optional parameter '"
176+
HANDLE_FATAL(<< "Input error: regression has a non-optional parameter '"
164177
<< CDataFrameTrainBoostedTreeRunner::DEPENDENT_VARIABLE_NAME << "'.")
165178
return nullptr;
166179
}

0 commit comments

Comments
 (0)