Skip to content

Commit 3f1b575

Browse files
authored
[ML] Return total SHAP per feature as a new result type (#1387)
This PR adds computation of the total feature importance values.
1 parent 3b7df50 commit 3f1b575

16 files changed

+517
-35
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
regression. (See {ml-pull}1340[#1340].)
7676
* Improvement in handling large inference model definitions. (See {ml-pull}1349[#1349].)
7777
* Add a peak_model_bytes field to model_size_stats. (See {ml-pull}1389[#1389].)
78+
* Calculate total feature importance as a new result type. (See {ml-pull}1387[#1387].)
7879

7980
=== Bug Fixes
8081

include/api/CDataFrameAnalysisRunner.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212

1313
#include <api/CDataFrameAnalysisInstrumentation.h>
1414
#include <api/CInferenceModelDefinition.h>
15+
#include <api/CInferenceModelMetadata.h>
1516
#include <api/ImportExport.h>
1617

1718
#include <rapidjson/fwd.h>
1819

20+
#include <boost/optional.hpp>
21+
1922
#include <cstddef>
2023
#include <functional>
2124
#include <memory>
@@ -66,6 +69,7 @@ class API_EXPORT CDataFrameAnalysisRunner {
6669
using TProgressRecorder = std::function<void(double)>;
6770
using TStrVecVec = std::vector<TStrVec>;
6871
using TInferenceModelDefinitionUPtr = std::unique_ptr<CInferenceModelDefinition>;
72+
using TOptionalInferenceModelMetadata = boost::optional<const CInferenceModelMetadata&>;
6973

7074
public:
7175
//! The intention is that concrete objects of this hierarchy are constructed
@@ -141,6 +145,9 @@ class API_EXPORT CDataFrameAnalysisRunner {
141145
virtual TInferenceModelDefinitionUPtr
142146
inferenceModelDefinition(const TStrVec& fieldNames, const TStrVecVec& categoryNames) const;
143147

148+
//! \return Serialisable metadata of the trained model, or an empty optional if the analysis provides none.
149+
virtual TOptionalInferenceModelMetadata inferenceModelMetadata() const;
150+
144151
//! \return Reference to the analysis instrumentation.
145152
virtual const CDataFrameAnalysisInstrumentation& instrumentation() const = 0;
146153
//! \return Reference to the analysis instrumentation.

include/api/CDataFrameAnalyzer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ class API_EXPORT CDataFrameAnalyzer {
8787
core::CRapidJsonConcurrentLineWriter& writer) const;
8888
void writeInferenceModel(const CDataFrameAnalysisRunner& analysis,
8989
core::CRapidJsonConcurrentLineWriter& writer) const;
90+
void writeInferenceModelMetadata(const CDataFrameAnalysisRunner& analysis,
91+
core::CRapidJsonConcurrentLineWriter& writer) const;
9092

9193
private:
9294
// This has values: -2 (unset), -1 (missing), >= 0 (control field index).

include/api/CDataFrameTrainBoostedTreeClassifierRunner.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <core/CSmallVector.h>
1111

1212
#include <api/CDataFrameTrainBoostedTreeRunner.h>
13+
#include <api/CInferenceModelMetadata.h>
1314
#include <api/ImportExport.h>
1415

1516
#include <rapidjson/fwd.h>
@@ -40,6 +41,8 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
4041
static const std::string NUM_TOP_CLASSES;
4142
static const std::string PREDICTION_FIELD_TYPE;
4243
static const std::string CLASS_ASSIGNMENT_OBJECTIVE;
44+
static const std::string CLASSES_FIELD_NAME;
45+
static const std::string CLASS_NAME_FIELD_NAME;
4346
static const TStrVec CLASS_ASSIGNMENT_OBJECTIVE_VALUES;
4447

4548
public:
@@ -70,6 +73,9 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
7073
inferenceModelDefinition(const TStrVec& fieldNames,
7174
const TStrVecVec& categoryNames) const override;
7275

76+
//! \return Serialisable metadata of the trained classification model.
77+
TOptionalInferenceModelMetadata inferenceModelMetadata() const override;
78+
7379
private:
7480
static TLossFunctionUPtr loss(std::size_t numberClasses);
7581

@@ -82,6 +88,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
8288
private:
8389
std::size_t m_NumTopClasses;
8490
EPredictionFieldType m_PredictionFieldType;
91+
mutable CInferenceModelMetadata m_InferenceModelMetadata;
8592
};
8693

8794
//! \brief Makes a core::CDataFrame boosted tree classification runner.

include/api/CDataFrameTrainBoostedTreeRegressionRunner.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <maths/CBoostedTreeLoss.h>
1111

1212
#include <api/CDataFrameTrainBoostedTreeRunner.h>
13+
#include <api/CInferenceModelMetadata.h>
1314
#include <api/ImportExport.h>
1415

1516
#include <rapidjson/fwd.h>
@@ -51,10 +52,15 @@ class API_EXPORT CDataFrameTrainBoostedTreeRegressionRunner final
5152
TInferenceModelDefinitionUPtr
5253
inferenceModelDefinition(const TStrVec& fieldNames,
5354
const TStrVecVec& categoryNameMap) const override;
55+
//! \return Serialisable metadata of the trained regression model.
56+
TOptionalInferenceModelMetadata inferenceModelMetadata() const override;
5457

5558
private:
5659
void validate(const core::CDataFrame& frame,
5760
std::size_t dependentVariableColumn) const override;
61+
62+
private:
63+
mutable CInferenceModelMetadata m_InferenceModelMetadata;
5864
};
5965

6066
//! \brief Makes a core::CDataFrame boosted tree regression runner.

include/api/CInferenceModelMetadata.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
#ifndef INCLUDED_ml_api_CInferenceModelMetadata_h
7+
#define INCLUDED_ml_api_CInferenceModelMetadata_h
8+
9+
#include <maths/CBasicStatistics.h>
10+
#include <maths/CLinearAlgebraEigen.h>
11+
12+
#include <api/CInferenceModelDefinition.h>
13+
#include <api/ImportExport.h>
14+
15+
#include <string>
16+
17+
namespace ml {
18+
namespace api {
19+
20+
//! \brief Class controls the serialization of the model meta information
21+
//! (such as total feature importance) into JSON format.
//!
//! DESCRIPTION:\n
//! Importances are accumulated per feature index via addToFeatureImportance
//! and emitted with write. The member definitions are not part of this header.
//!
//! NOTE(review): this header uses std::vector, std::unordered_map and
//! std::size_t but only includes <string> directly; presumably the others
//! arrive transitively via the maths/api headers. Consider including
//! <vector>, <unordered_map> and <cstddef> explicitly.
22+
class API_EXPORT CInferenceModelMetadata {
23+
public:
//! Names of the JSON tags. These are only declared here; their values
//! are defined out of line.
24+
static const std::string JSON_CLASS_NAME_TAG;
25+
static const std::string JSON_CLASSES_TAG;
26+
static const std::string JSON_FEATURE_NAME_TAG;
27+
static const std::string JSON_IMPORTANCE_TAG;
28+
static const std::string JSON_MAX_TAG;
29+
static const std::string JSON_MEAN_MAGNITUDE_TAG;
30+
static const std::string JSON_MIN_TAG;
31+
static const std::string JSON_MODEL_METADATA_TAG;
32+
static const std::string JSON_TOTAL_FEATURE_IMPORTANCE_TAG;
33+
34+
public:
//! Dense vector of doubles used for the importances of a single feature.
35+
using TVector = maths::CDenseVector<double>;
36+
using TStrVec = std::vector<std::string>;
37+
using TRapidJsonWriter = core::CRapidJsonConcurrentLineWriter;
38+
39+
public:
40+
//! Writes metadata using \p writer.
41+
void write(TRapidJsonWriter& writer) const;
//! NOTE(review): presumably stores \p columnNames in m_ColumnNames so that
//! feature indices can be reported by name; confirm against the definition.
42+
void columnNames(const TStrVec& columnNames);
//! NOTE(review): presumably stores \p classValues in m_ClassValues;
//! confirm against the definition.
43+
void classValues(const TStrVec& classValues);
//! \return A string identifying this result type.
//! NOTE(review): presumably the JSON key under which the metadata object
//! is written; confirm against the definition and the callers.
44+
const std::string& typeString() const;
45+
//! Add importances \p values to the feature with index \p i to calculate total feature importance.
46+
//! Total feature importance is the mean of the magnitudes of importances for individual data points.
47+
void addToFeatureImportance(std::size_t i, const TVector& values);
48+
49+
private:
//! Mean and variance accumulator over vector valued samples.
50+
using TMeanVarAccumulator = maths::CBasicStatistics::SSampleMeanVar<TVector>::TAccumulator;
//! A collection of scalar min/max accumulators.
//! NOTE(review): presumably one per component of the importance vector; confirm.
51+
using TMinMaxAccumulator = std::vector<maths::CBasicStatistics::CMinMax<double>>;
//! Maps keyed by std::size_t.
//! NOTE(review): the key is presumably the feature index passed to
//! addToFeatureImportance; confirm against the definition.
52+
using TSizeMeanVarAccumulatorUMap = std::unordered_map<std::size_t, TMeanVarAccumulator>;
53+
using TSizeMinMaxAccumulatorUMap = std::unordered_map<std::size_t, TMinMaxAccumulator>;
54+
55+
private:
//! Writes the total feature importance using \p writer.
//! NOTE(review): presumably called from write; confirm against the definition.
56+
void writeTotalFeatureImportance(TRapidJsonWriter& writer) const;
57+
58+
private:
//! Accumulated mean/variance statistics of the importances.
59+
TSizeMeanVarAccumulatorUMap m_TotalShapValuesMeanVar;
//! Accumulated minimum/maximum statistics of the importances.
60+
TSizeMinMaxAccumulatorUMap m_TotalShapValuesMinMax;
61+
TStrVec m_ColumnNames;
62+
TStrVec m_ClassValues;
63+
};
64+
}
65+
}
66+
67+
#endif //INCLUDED_ml_api_CInferenceModelMetadata_h

include/maths/CBasicStatistics.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ class MATHS_EXPORT CBasicStatistics {
245245

246246
if (ORDER > 1) {
247247
T r{x - s_Moments[0]};
248-
T r2{r * r};
248+
T r2{las::componentwise(r) * las::componentwise(r)};
249249
T dMean{mean - s_Moments[0]};
250250
T dMean2{las::componentwise(dMean) * las::componentwise(dMean)};
251251
T variance{s_Moments[1]};

include/maths/CTreeShapFeatureImportance.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
7373
//! Get the maximum depth of any tree in \p forest.
7474
static std::size_t depth(const TTreeVec& forest);
7575

76+
//! Get the column names.
77+
const TStrVec& columnNames() const;
78+
7679
private:
7780
//! Collects the elements of the path through decision tree that are updated together
7881
struct SPathElement {

lib/api/CDataFrameAnalysisRunner.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ CDataFrameAnalysisRunner::inferenceModelDefinition(const TStrVec& /*fieldNames*/
193193
return TInferenceModelDefinitionUPtr();
194194
}
195195

196+
// Default implementation of the virtual inferenceModelMetadata: returns an
// empty optional, i.e. this runner has no model metadata to report. Runners
// which produce metadata override this.
CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
197+
CDataFrameAnalysisRunner::inferenceModelMetadata() const {
198+
// A default constructed optional holds no reference.
return TOptionalInferenceModelMetadata();
199+
}
200+
196201
CDataFrameAnalysisRunnerFactory::TRunnerUPtr
197202
CDataFrameAnalysisRunnerFactory::make(const CDataFrameAnalysisSpecification& spec) const {
198203
auto result = this->makeImpl(spec);

lib/api/CDataFrameAnalyzer.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ void CDataFrameAnalyzer::run() {
144144
analysisRunner->waitToFinish();
145145
this->writeInferenceModel(*analysisRunner, outputWriter);
146146
this->writeResultsOf(*analysisRunner, outputWriter);
147+
// TODO reactivate once Java parsing is ready
148+
// this->writeInferenceModelMetadata(*analysisRunner, outputWriter);
147149
}
148150
}
149151

@@ -286,6 +288,21 @@ void CDataFrameAnalyzer::writeInferenceModel(const CDataFrameAnalysisRunner& ana
286288
writer.flush();
287289
}
288290

291+
// Writes the model metadata of analysis, if it has any, as a single JSON
// object of the form {<typeString()>: {<metadata body>}} using writer. The
// writer is flushed whether or not anything was written.
//
// NOTE(review): the only call visible here, in CDataFrameAnalyzer::run, is
// commented out pending Java-side parsing support.
void CDataFrameAnalyzer::writeInferenceModelMetadata(const CDataFrameAnalysisRunner& analysis,
292+
core::CRapidJsonConcurrentLineWriter& writer) const {
293+
// Write model meta information
294+
// The result is an optional reference: empty when the runner provides no metadata.
auto modelMetadata = analysis.inferenceModelMetadata();
295+
if (modelMetadata) {
296+
// Outer object: keyed by the metadata's type string.
writer.StartObject();
297+
writer.Key(modelMetadata->typeString());
298+
// Inner object: the body is filled in by the metadata itself.
writer.StartObject();
299+
modelMetadata->write(writer);
300+
writer.EndObject();
301+
writer.EndObject();
302+
}
303+
// Flush unconditionally, even if no metadata was written.
writer.flush();
304+
}
305+
289306
void CDataFrameAnalyzer::writeResultsOf(const CDataFrameAnalysisRunner& analysis,
290307
core::CRapidJsonConcurrentLineWriter& writer) const {
291308

0 commit comments

Comments
 (0)