Skip to content

Commit 309db87

Browse files
authored
[ML] Add information about samples per node to the tree (#991)
This PR extends the definition of the tree node by adding information about the number of training samples that passed through the node (numberSamples or number_samples). The json schema for inference model is adjusted accordingly. Since this change the schema for persist/restore of the tree implementation, I bumped the version and removed 7.5 and 7.6 from the list of supported version. My reasoning: restoring from old schema and setting number samples to 0 would break feature importance at inference time. I also adjust feature importance computation to use pre-computed number samples instead of recomputing it on the fly.
1 parent 2916819 commit 309db87

16 files changed

+215
-139
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ progress, memory usage, etc. (See {ml-pull}906[#906].)
4646

4747
* Improve initialization of learn rate for better and more stable results in regression
4848
and classification. (See {ml-pull}948[#948].)
49+
* Add number of processed training samples to the definition of decision tree nodes.
50+
(See {ml-pull}991[#991].)
4951
* Add new model_size_stats fields to instrument categorization. (See {ml-pull}948[#948]
5052
and {pull}51879[#51879], issue: {issue}50794[#50749].)
5153

include/api/CBoostedTreeInferenceModelBuilder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class API_EXPORT CBoostedTreeInferenceModelBuilder : public maths::CBoostedTree:
3838
bool assignMissingToLeft,
3939
double nodeValue,
4040
double gain,
41+
std::size_t numberSamples,
4142
maths::CBoostedTreeNode::TOptionalNodeIndex leftChild,
4243
maths::CBoostedTreeNode::TOptionalNodeIndex rightChild) override;
4344
void addIdentityEncoding(std::size_t inputColumnIndex) override;

include/api/CInferenceModelDefinition.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ class API_EXPORT CTree final : public CTrainedModel {
158158
double threshold,
159159
bool defaultLeft,
160160
double leafValue,
161-
size_t splitFeature,
161+
std::size_t splitFeature,
162+
std::size_t numberSamples,
162163
const TOptionalNodeIndex& leftChild,
163164
const TOptionalNodeIndex& rightChild,
164165
const TOptionalDouble& splitGain);
@@ -175,6 +176,7 @@ class API_EXPORT CTree final : public CTrainedModel {
175176
TOptionalNodeIndex m_LeftChild;
176177
TOptionalNodeIndex m_RightChild;
177178
std::size_t m_SplitFeature;
179+
std::size_t m_NumberSamples;
178180
double m_Threshold;
179181
double m_LeafValue;
180182
TOptionalDouble m_SplitGain;

include/maths/CBoostedTree.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ class MATHS_EXPORT CBoostedTreeNode final {
293293
bool assignMissingToLeft,
294294
double nodeValue,
295295
double gain,
296+
std::size_t numberSamples,
296297
TOptionalNodeIndex leftChild,
297298
TOptionalNodeIndex rightChild) = 0;
298299
};
@@ -334,6 +335,12 @@ class MATHS_EXPORT CBoostedTreeNode final {
334335
//! Get the total curvature at the rows below this node.
335336
double curvature() const { return m_Curvature; }
336337

338+
//! Set the number of samples to \p value.
339+
void numberSamples(std::size_t value);
340+
341+
//! Get number of samples affected by the node.
342+
std::size_t numberSamples() const;
343+
337344
//! Get the index of the left child node.
338345
TNodeIndex leftChildIndex() const { return m_LeftChild.get(); }
339346

@@ -376,6 +383,7 @@ class MATHS_EXPORT CBoostedTreeNode final {
376383
double m_NodeValue = 0.0;
377384
double m_Gain = 0.0;
378385
double m_Curvature = 0.0;
386+
std::size_t m_NumberSamples = 0;
379387
};
380388

381389
//! \brief A boosted regression tree model.

include/maths/CBoostedTreeImpl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ class MATHS_EXPORT CBoostedTreeImpl final {
252252
//! Get the root node of \p tree.
253253
static const CBoostedTreeNode& root(const TNodeVec& tree);
254254

255+
//! Get the root node of \p tree.
256+
static CBoostedTreeNode& root(TNodeVec& tree);
257+
255258
//! Get the forest's prediction for \p row.
256259
static double predictRow(const CEncodedDataFrameRowRef& row, const TNodeVecVec& forest);
257260

@@ -287,6 +290,9 @@ class MATHS_EXPORT CBoostedTreeImpl final {
287290
//! Record the training state using the \p recordTrainState callback function
288291
void recordState(const TTrainingStateCallback& recordTrainState) const;
289292

293+
//! Populate numberSamples field in the m_BestForest
294+
void computeNumberSamples(const core::CDataFrame& frame);
295+
290296
private:
291297
mutable CPRNG::CXorOShiro128Plus m_Rng;
292298
std::size_t m_NumberThreads;

include/maths/CTreeShapFeatureImportance.h

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,9 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
4040
//! by \p offset.
4141
void shap(core::CDataFrame& frame, const CDataFrameCategoryEncoder& encoder, std::size_t offset);
4242

43-
//! Compute number of training samples from \p frame that pass every node in the \p tree.
44-
static TDoubleVec samplesPerNode(const TTree& tree,
45-
const core::CDataFrame& frame,
46-
const CDataFrameCategoryEncoder& encoder,
47-
std::size_t numThreads);
48-
4943
//! Recursively computes inner node values as weighted average of the children (leaf) values
5044
//! \returns The maximum depth the the tree.
51-
static std::size_t updateNodeValues(TTree& tree,
52-
std::size_t nodeIndex,
53-
const TDoubleVec& samplesPerNode,
54-
std::size_t depth);
45+
static size_t updateNodeValues(TTree& tree, std::size_t nodeIndex, std::size_t depth);
5546

5647
//! Get the reference to the trees.
5748
TTreeVec& trees() { return m_Trees; }
@@ -126,7 +117,6 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
126117
//! Recursively traverses all pathes in the \p tree and updated SHAP values once it hits a leaf.
127118
//! Ref. Algorithm 2 in the paper by Lundberg et al.
128119
void shapRecursive(const TTree& tree,
129-
const TDoubleVec& samplesPerNode,
130120
const CDataFrameCategoryEncoder& encoder,
131121
const CEncodedDataFrameRowRef& encodedRow,
132122
SPath& splitPath,
@@ -146,7 +136,6 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
146136
private:
147137
TTreeVec m_Trees;
148138
std::size_t m_NumberThreads;
149-
TDoubleVecVec m_SamplesPerNode;
150139
};
151140
}
152141
}

lib/api/CBoostedTreeInferenceModelBuilder.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ void CBoostedTreeInferenceModelBuilder::addNode(std::size_t splitFeature,
8989
bool assignMissingToLeft,
9090
double nodeValue,
9191
double gain,
92+
std::size_t numberSamples,
9293
maths::CBoostedTreeNode::TOptionalNodeIndex leftChild,
9394
maths::CBoostedTreeNode::TOptionalNodeIndex rightChild) {
9495
auto ensemble{static_cast<CEnsemble*>(m_Definition.trainedModel().get())};
@@ -97,8 +98,9 @@ void CBoostedTreeInferenceModelBuilder::addNode(std::size_t splitFeature,
9798
if (tree == nullptr) {
9899
HANDLE_FATAL(<< "Internal error. Tree points to a nullptr.")
99100
}
100-
tree->treeStructure().emplace_back(tree->size(), splitValue, assignMissingToLeft, nodeValue,
101-
splitFeature, leftChild, rightChild, gain);
101+
tree->treeStructure().emplace_back(tree->size(), splitValue, assignMissingToLeft,
102+
nodeValue, splitFeature, numberSamples,
103+
leftChild, rightChild, gain);
102104
}
103105

104106
CBoostedTreeInferenceModelBuilder::CBoostedTreeInferenceModelBuilder(TStrVec fieldNames,

lib/api/CInferenceModelDefinition.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ const std::string JSON_LEFT_CHILD_TAG{"left_child"};
3434
const std::string JSON_LOGISTIC_REGRESSION_TAG{"logistic_regression"};
3535
const std::string JSON_LT{"lt"};
3636
const std::string JSON_NODE_INDEX_TAG{"node_index"};
37+
const std::string JSON_NUMBER_SAMPLES_TAG{"number_samples"};
3738
const std::string JSON_ONE_HOT_ENCODING_TAG{"one_hot_encoding"};
3839
const std::string JSON_PREPROCESSORS_TAG{"preprocessors"};
3940
const std::string JSON_RIGHT_CHILD_TAG{"right_child"};
@@ -79,6 +80,9 @@ void addJsonArray(const std::string& tag,
7980
void CTree::CTreeNode::addToDocument(rapidjson::Value& parentObject,
8081
TRapidJsonWriter& writer) const {
8182
writer.addMember(JSON_NODE_INDEX_TAG, rapidjson::Value(m_NodeIndex).Move(), parentObject);
83+
writer.addMember(
84+
JSON_NUMBER_SAMPLES_TAG,
85+
rapidjson::Value(static_cast<std::uint64_t>(m_NumberSamples)).Move(), parentObject);
8286

8387
if (m_LeftChild) {
8488
// internal node
@@ -118,11 +122,13 @@ CTree::CTreeNode::CTreeNode(TNodeIndex nodeIndex,
118122
bool defaultLeft,
119123
double leafValue,
120124
std::size_t splitFeature,
125+
std::size_t numberSamples,
121126
const TOptionalNodeIndex& leftChild,
122127
const TOptionalNodeIndex& rightChild,
123128
const TOptionalDouble& splitGain)
124-
: m_DefaultLeft(defaultLeft), m_NodeIndex(nodeIndex), m_LeftChild(leftChild),
125-
m_RightChild(rightChild), m_SplitFeature(splitFeature),
129+
: m_DefaultLeft(defaultLeft), m_NodeIndex(nodeIndex),
130+
m_LeftChild(leftChild), m_RightChild(rightChild),
131+
m_SplitFeature(splitFeature), m_NumberSamples(numberSamples),
126132
m_Threshold(threshold), m_LeafValue(leafValue), m_SplitGain(splitGain) {
127133
}
128134

lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,9 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoImportance, SFixture) {
299299
// c1 explains 95% of the prediction value, i.e. the difference from the prediction is less than 2%.
300300
BOOST_REQUIRE_CLOSE(c1, prediction, 5.0);
301301
for (const auto& feature : {"c2", "c3", "c4"}) {
302-
BOOST_REQUIRE_SMALL(readShapValue(result, feature), 2.0);
303-
cNoImportanceMean.add(std::fabs(readShapValue(result, feature)));
302+
double c = readShapValue(result, feature);
303+
BOOST_REQUIRE_SMALL(c, 2.0);
304+
cNoImportanceMean.add(std::fabs(c));
304305
}
305306
}
306307
}

lib/api/unittest/testfiles/inference_json_schema/model_definition.schema.json

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@
6666
},
6767
"right_child": {
6868
"type": "integer"
69+
},
70+
"number_samples": {
71+
"description": "Number of training samples that were affected by the node.",
72+
"type": "integer"
6973
}
7074
},
7175
"required": [
@@ -75,7 +79,8 @@
7579
"decision_type",
7680
"default_left",
7781
"left_child",
78-
"right_child"
82+
"right_child",
83+
"number_samples"
7984
],
8085
"additionalProperties": false
8186
},
@@ -88,11 +93,16 @@
8893
},
8994
"leaf_value": {
9095
"type": "number"
96+
},
97+
"number_samples": {
98+
"description": "Number of training samples that were affected by the node.",
99+
"type": "integer"
91100
}
92101
},
93102
"required": [
94103
"node_index",
95-
"leaf_value"
104+
"leaf_value",
105+
"number_samples"
96106
],
97107
"additionalProperties": false
98108
},
@@ -234,10 +244,14 @@
234244
"items": {
235245
"type": "number"
236246
}
247+
},
248+
"num_classes": {
249+
"type": "integer"
237250
}
238251
},
239252
"required": [
240-
"weights"
253+
"weights",
254+
"num_classes"
241255
]
242256
}
243257
},

lib/maths/CBoostedTree.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ const std::string SPLIT_FEATURE_TAG{"split_feature"};
2727
const std::string ASSIGN_MISSING_TO_LEFT_TAG{"assign_missing_to_left "};
2828
const std::string NODE_VALUE_TAG{"node_value"};
2929
const std::string SPLIT_VALUE_TAG{"split_value"};
30+
const std::string NUMBER_SAMPLES_TAG{"number_samples"};
3031

3132
double LOG_EPSILON{std::log(100.0 * std::numeric_limits<double>::epsilon())};
3233

@@ -393,6 +394,7 @@ void CBoostedTreeNode::acceptPersistInserter(core::CStatePersistInserter& insert
393394
core::CPersistUtils::persist(ASSIGN_MISSING_TO_LEFT_TAG, m_AssignMissingToLeft, inserter);
394395
core::CPersistUtils::persist(NODE_VALUE_TAG, m_NodeValue, inserter);
395396
core::CPersistUtils::persist(SPLIT_VALUE_TAG, m_SplitValue, inserter);
397+
core::CPersistUtils::persist(NUMBER_SAMPLES_TAG, m_NumberSamples, inserter);
396398
}
397399

398400
bool CBoostedTreeNode::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
@@ -411,6 +413,8 @@ bool CBoostedTreeNode::acceptRestoreTraverser(core::CStateRestoreTraverser& trav
411413
core::CPersistUtils::restore(NODE_VALUE_TAG, m_NodeValue, traverser))
412414
RESTORE(SPLIT_VALUE_TAG,
413415
core::CPersistUtils::restore(SPLIT_VALUE_TAG, m_SplitValue, traverser))
416+
RESTORE(NUMBER_SAMPLES_TAG,
417+
core::CPersistUtils::restore(NUMBER_SAMPLES_TAG, m_NumberSamples, traverser))
414418
} while (traverser.next());
415419
return true;
416420
}
@@ -435,8 +439,16 @@ std::ostringstream& CBoostedTreeNode::doPrint(std::string pad,
435439
}
436440

437441
void CBoostedTreeNode::accept(CVisitor& visitor) const {
438-
visitor.addNode(m_SplitFeature, m_SplitValue, m_AssignMissingToLeft,
439-
m_NodeValue, m_Gain, m_LeftChild, m_RightChild);
442+
visitor.addNode(m_SplitFeature, m_SplitValue, m_AssignMissingToLeft, m_NodeValue,
443+
m_Gain, m_NumberSamples, m_LeftChild, m_RightChild);
444+
}
445+
446+
void CBoostedTreeNode::numberSamples(std::size_t numberSamples) {
447+
m_NumberSamples = numberSamples;
448+
}
449+
450+
std::size_t CBoostedTreeNode::numberSamples() const {
451+
return m_NumberSamples;
440452
}
441453

442454
CBoostedTree::CBoostedTree(core::CDataFrame& frame,

lib/maths/CBoostedTreeImpl.cc

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
226226
this->restoreBestHyperparameters();
227227
std::tie(m_BestForest, std::ignore) = this->trainForest(
228228
frame, allTrainingRowsMask, allTrainingRowsMask, m_TrainingProgress);
229+
229230
m_Instrumentation->nextStep(static_cast<std::uint32_t>(m_CurrentRound));
230231
this->recordState(recordTrainStateCallback);
231232

@@ -242,12 +243,56 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
242243

243244
this->computeProbabilityAtWhichToAssignClassOne(frame);
244245

246+
// populate numberSamples field in the final forest
247+
this->computeNumberSamples(frame);
248+
245249
// Force progress to one because we can have early exit from loop skip altogether.
246250
m_Instrumentation->updateProgress(1.0);
247251
m_Instrumentation->updateMemoryUsage(
248252
static_cast<std::int64_t>(this->memoryUsage()) - lastMemoryUsage);
249253
}
250254

255+
void CBoostedTreeImpl::computeNumberSamples(const core::CDataFrame& frame) {
256+
for (auto& tree : m_BestForest) {
257+
if (tree.size() == 1) {
258+
root(tree).numberSamples(frame.numberRows());
259+
} else {
260+
auto result = frame.readRows(
261+
m_NumberThreads,
262+
core::bindRetrievableState(
263+
[&](TSizeVec& samplesPerNode, const TRowItr& beginRows, const TRowItr& endRows) {
264+
for (auto row = beginRows; row != endRows; ++row) {
265+
auto encodedRow{m_Encoder->encode(*row)};
266+
const CBoostedTreeNode* node{&root(tree)};
267+
samplesPerNode[0] += 1;
268+
std::size_t nextIndex;
269+
while (node->isLeaf() == false) {
270+
if (node->assignToLeft(encodedRow)) {
271+
nextIndex = node->leftChildIndex();
272+
} else {
273+
nextIndex = node->rightChildIndex();
274+
}
275+
samplesPerNode[nextIndex] += 1;
276+
node = &(tree[nextIndex]);
277+
}
278+
}
279+
},
280+
TSizeVec(tree.size())));
281+
auto& state = result.first;
282+
TSizeVec totalSamplesPerNode{std::move(state[0].s_FunctionState)};
283+
for (std::size_t i = 1; i < state.size(); ++i) {
284+
for (std::size_t nodeIndex = 0;
285+
nodeIndex < totalSamplesPerNode.size(); ++nodeIndex) {
286+
totalSamplesPerNode[nodeIndex] += state[i].s_FunctionState[nodeIndex];
287+
}
288+
}
289+
for (std::size_t i = 0; i < tree.size(); ++i) {
290+
tree[i].numberSamples(totalSamplesPerNode[i]);
291+
}
292+
}
293+
}
294+
}
295+
251296
void CBoostedTreeImpl::recordState(const TTrainingStateCallback& recordTrainState) const {
252297
recordTrainState([this](core::CStatePersistInserter& inserter) {
253298
this->acceptPersistInserter(inserter);
@@ -997,6 +1042,10 @@ const CBoostedTreeNode& CBoostedTreeImpl::root(const TNodeVec& tree) {
9971042
return tree[0];
9981043
}
9991044

1045+
CBoostedTreeNode& CBoostedTreeImpl::root(TNodeVec& tree) {
1046+
return tree[0];
1047+
}
1048+
10001049
double CBoostedTreeImpl::predictRow(const CEncodedDataFrameRowRef& row,
10011050
const TNodeVecVec& forest) {
10021051
double result{0.0};
@@ -1148,9 +1197,8 @@ std::size_t CBoostedTreeImpl::maximumTreeSize(std::size_t numberRows) const {
11481197
}
11491198

11501199
namespace {
1151-
const std::string VERSION_7_5_TAG{"7.5"};
1152-
const std::string VERSION_7_6_TAG{"7.6"};
1153-
const TStrVec SUPPORTED_VERSIONS{VERSION_7_5_TAG, VERSION_7_6_TAG};
1200+
const std::string VERSION_7_7_TAG{"7.7"};
1201+
const TStrVec SUPPORTED_VERSIONS{VERSION_7_7_TAG};
11541202

11551203
const std::string BAYESIAN_OPTIMIZATION_TAG{"bayesian_optimization"};
11561204
const std::string BEST_FOREST_TAG{"best_forest"};
@@ -1214,7 +1262,7 @@ CBoostedTreeImpl::TStrVec CBoostedTreeImpl::bestHyperparameterNames() {
12141262
}
12151263

12161264
void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
1217-
core::CPersistUtils::persist(VERSION_7_6_TAG, "", inserter);
1265+
core::CPersistUtils::persist(VERSION_7_7_TAG, "", inserter);
12181266
core::CPersistUtils::persist(BAYESIAN_OPTIMIZATION_TAG, *m_BayesianOptimization, inserter);
12191267
core::CPersistUtils::persist(BEST_FOREST_TEST_LOSS_TAG, m_BestForestTestLoss, inserter);
12201268
core::CPersistUtils::persist(CURRENT_ROUND_TAG, m_CurrentRound, inserter);
@@ -1266,15 +1314,7 @@ void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& insert
12661314
}
12671315

12681316
bool CBoostedTreeImpl::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
1269-
if (traverser.name() == VERSION_7_5_TAG) {
1270-
// Force downsample factor to 1.0.
1271-
m_DownsampleFactorOverride = 1.0;
1272-
m_DownsampleFactor = 1.0;
1273-
m_BestHyperparameters.downsampleFactor(1.0);
1274-
// We can't stop cross-validation early because we haven't gathered the
1275-
// per fold test losses.
1276-
m_StopCrossValidationEarly = false;
1277-
} else if (traverser.name() != VERSION_7_6_TAG) {
1317+
if (traverser.name() != VERSION_7_7_TAG) {
12781318
LOG_ERROR(<< "Input error: unsupported state serialization version. "
12791319
<< "Currently supported versions: "
12801320
<< core::CContainerPrinter::print(SUPPORTED_VERSIONS) << ".");

0 commit comments

Comments
 (0)