
[ML] Add information about samples per node to the tree #991


Merged
merged 25 commits on Feb 17, 2020
Changes from all commits
Commits
25 commits
feacc51
additions to definition
valeriy42 Feb 5, 2020
714633f
Merge branch 'master' into ml-cpp-850
valeriy42 Feb 5, 2020
4245b52
wire through number of samples
valeriy42 Feb 5, 2020
2cc8b0c
unit test adjusted
valeriy42 Feb 6, 2020
e7229e1
refactoring, formatting
valeriy42 Feb 6, 2020
69f6650
move setter for number of samples
valeriy42 Feb 6, 2020
1d23037
add enhancement note
valeriy42 Feb 6, 2020
a85304d
bump version for persist/restore
valeriy42 Feb 6, 2020
ee63036
adjust SHAP algorithm to use precomputed number samples
valeriy42 Feb 6, 2020
736416a
add comments
valeriy42 Feb 6, 2020
f9edc6b
updated test bounds
valeriy42 Feb 7, 2020
bcadd65
rename variables for consistency
valeriy42 Feb 7, 2020
868d49f
version bump fixed
valeriy42 Feb 7, 2020
db799ed
Merge branch 'master' into ml-cpp-850
valeriy42 Feb 7, 2020
763bc4e
formatting
valeriy42 Feb 7, 2020
aaf79b7
clang warning fixed.
valeriy42 Feb 10, 2020
097dc25
samples per node computation as a standalone method
valeriy42 Feb 13, 2020
29dd782
formatting
valeriy42 Feb 14, 2020
6a61fa9
Merge branch 'master' of https://github.com/elastic/ml-cpp into ml-cp…
valeriy42 Feb 17, 2020
21e714d
explicit numberSamples vector removed
valeriy42 Feb 17, 2020
ed15070
Formatting
valeriy42 Feb 17, 2020
c30908e
changes in CBoostedTreeLeafNodeStatistics reverted
valeriy42 Feb 5, 2020
3fbb75e
use root() and fix conversions
valeriy42 Feb 17, 2020
93a4432
move number samples computations
valeriy42 Feb 17, 2020
f1cc310
fix for root method
valeriy42 Feb 17, 2020
2 changes: 2 additions & 0 deletions docs/CHANGELOG.asciidoc
@@ -46,6 +46,8 @@ progress, memory usage, etc. (See {ml-pull}906[#906].)

* Improve initialization of learn rate for better and more stable results in regression
and classification. (See {ml-pull}948[#948].)
* Add number of processed training samples to the definition of decision tree nodes.
(See {ml-pull}991[#991].)
* Add new model_size_stats fields to instrument categorization. (See {ml-pull}948[#948]
and {pull}51879[#51879], issue: {issue}50794[#50749].)

1 change: 1 addition & 0 deletions include/api/CBoostedTreeInferenceModelBuilder.h
@@ -38,6 +38,7 @@ class API_EXPORT CBoostedTreeInferenceModelBuilder : public maths::CBoostedTree:
bool assignMissingToLeft,
double nodeValue,
double gain,
std::size_t numberSamples,
maths::CBoostedTreeNode::TOptionalNodeIndex leftChild,
maths::CBoostedTreeNode::TOptionalNodeIndex rightChild) override;
void addIdentityEncoding(std::size_t inputColumnIndex) override;
4 changes: 3 additions & 1 deletion include/api/CInferenceModelDefinition.h
@@ -158,7 +158,8 @@ class API_EXPORT CTree final : public CTrainedModel {
double threshold,
bool defaultLeft,
double leafValue,
size_t splitFeature,
std::size_t splitFeature,
std::size_t numberSamples,
const TOptionalNodeIndex& leftChild,
const TOptionalNodeIndex& rightChild,
const TOptionalDouble& splitGain);
@@ -175,6 +176,7 @@ class API_EXPORT CTree final : public CTrainedModel {
TOptionalNodeIndex m_LeftChild;
TOptionalNodeIndex m_RightChild;
std::size_t m_SplitFeature;
std::size_t m_NumberSamples;
double m_Threshold;
double m_LeafValue;
TOptionalDouble m_SplitGain;
8 changes: 8 additions & 0 deletions include/maths/CBoostedTree.h
@@ -265,6 +265,7 @@ class MATHS_EXPORT CBoostedTreeNode final {
bool assignMissingToLeft,
double nodeValue,
double gain,
std::size_t numberSamples,
TOptionalNodeIndex leftChild,
TOptionalNodeIndex rightChild) = 0;
};
@@ -306,6 +307,12 @@ class MATHS_EXPORT CBoostedTreeNode final {
//! Get the total curvature at the rows below this node.
double curvature() const { return m_Curvature; }

//! Set the number of samples to \p value.
void numberSamples(std::size_t value);

//! Get number of samples affected by the node.
std::size_t numberSamples() const;

//! Get the index of the left child node.
TNodeIndex leftChildIndex() const { return m_LeftChild.get(); }

@@ -348,6 +355,7 @@ class MATHS_EXPORT CBoostedTreeNode final {
double m_NodeValue = 0.0;
double m_Gain = 0.0;
double m_Curvature = 0.0;
std::size_t m_NumberSamples = 0;
};

//! \brief A boosted regression tree model.
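Below is a small usage sketch of the accessor pair added to CBoostedTreeNode above, written against a hypothetical stand-in type rather than the ml-cpp class:

#include <cassert>
#include <cstddef>

// Toy stand-in exposing the same numberSamples() getter/setter shape as the
// header above; it is not the ml-cpp class.
class ToyNode {
public:
    void numberSamples(std::size_t value) { m_NumberSamples = value; }
    std::size_t numberSamples() const { return m_NumberSamples; }

private:
    std::size_t m_NumberSamples = 0;
};

int main() {
    ToyNode root;
    root.numberSamples(500); // e.g. every training row reaches the root
    assert(root.numberSamples() == 500);
    return 0;
}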
6 changes: 6 additions & 0 deletions include/maths/CBoostedTreeImpl.h
@@ -252,6 +252,9 @@ class MATHS_EXPORT CBoostedTreeImpl final {
//! Get the root node of \p tree.
static const CBoostedTreeNode& root(const TNodeVec& tree);

//! Get the root node of \p tree.
static CBoostedTreeNode& root(TNodeVec& tree);

//! Get the forest's prediction for \p row.
static double predictRow(const CEncodedDataFrameRowRef& row, const TNodeVecVec& forest);

@@ -287,6 +290,9 @@ class MATHS_EXPORT CBoostedTreeImpl final {
//! Record the training state using the \p recordTrainState callback function
void recordState(const TTrainingStateCallback& recordTrainState) const;

//! Populate numberSamples field in the m_BestForest
void computeNumberSamples(const core::CDataFrame& frame);

private:
mutable CPRNG::CXorOShiro128Plus m_Rng;
std::size_t m_NumberThreads;
13 changes: 1 addition & 12 deletions include/maths/CTreeShapFeatureImportance.h
@@ -40,18 +40,9 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
//! by \p offset.
void shap(core::CDataFrame& frame, const CDataFrameCategoryEncoder& encoder, std::size_t offset);

//! Compute number of training samples from \p frame that pass every node in the \p tree.
static TDoubleVec samplesPerNode(const TTree& tree,
const core::CDataFrame& frame,
const CDataFrameCategoryEncoder& encoder,
std::size_t numThreads);

//! Recursively computes inner node values as weighted average of the children (leaf) values
//! \returns The maximum depth of the tree.
static std::size_t updateNodeValues(TTree& tree,
std::size_t nodeIndex,
const TDoubleVec& samplesPerNode,
std::size_t depth);
static size_t updateNodeValues(TTree& tree, std::size_t nodeIndex, std::size_t depth);

//! Get the reference to the trees.
TTreeVec& trees() { return m_Trees; }
@@ -126,7 +117,6 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
//! Recursively traverses all paths in the \p tree and updates SHAP values once it hits a leaf.
//! Ref. Algorithm 2 in the paper by Lundberg et al.
void shapRecursive(const TTree& tree,
const TDoubleVec& samplesPerNode,
const CDataFrameCategoryEncoder& encoder,
const CEncodedDataFrameRowRef& encodedRow,
SPath& splitPath,
@@ -146,7 +136,6 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
private:
TTreeVec m_Trees;
std::size_t m_NumberThreads;
TDoubleVecVec m_SamplesPerNode;
};
}
}
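With the counts stored on the nodes themselves, the samplesPerNode vector and the extra parameter to updateNodeValues become unnecessary. A rough sketch of the sample-weighted averaging this enables, using an illustrative node struct rather than the ml-cpp types:

#include <cstddef>
#include <vector>

// Minimal illustrative node; the real implementation works on CBoostedTreeNode.
struct SketchNode {
    bool leaf = false;
    double value = 0.0;
    std::size_t numberSamples = 0;
    std::size_t leftChild = 0;
    std::size_t rightChild = 0;
};

// Set each inner node's value to the sample-weighted average of its children,
// reading the counts stored on the nodes instead of a separate vector.
double weightedNodeValue(std::vector<SketchNode>& tree, std::size_t index) {
    SketchNode& node = tree[index];
    if (node.leaf) {
        return node.value;
    }
    double leftValue = weightedNodeValue(tree, node.leftChild);
    double rightValue = weightedNodeValue(tree, node.rightChild);
    auto leftCount = static_cast<double>(tree[node.leftChild].numberSamples);
    auto rightCount = static_cast<double>(tree[node.rightChild].numberSamples);
    double total = leftCount + rightCount;
    node.value = total > 0.0
                     ? (leftCount * leftValue + rightCount * rightValue) / total
                     : 0.5 * (leftValue + rightValue);
    return node.value;
}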
6 changes: 4 additions & 2 deletions lib/api/CBoostedTreeInferenceModelBuilder.cc
@@ -89,6 +89,7 @@ void CBoostedTreeInferenceModelBuilder::addNode(std::size_t splitFeature,
bool assignMissingToLeft,
double nodeValue,
double gain,
std::size_t numberSamples,
maths::CBoostedTreeNode::TOptionalNodeIndex leftChild,
maths::CBoostedTreeNode::TOptionalNodeIndex rightChild) {
auto ensemble{static_cast<CEnsemble*>(m_Definition.trainedModel().get())};
@@ -97,8 +98,9 @@ void CBoostedTreeInferenceModelBuilder::addNode(std::size_t splitFeature,
if (tree == nullptr) {
HANDLE_FATAL(<< "Internal error. Tree points to a nullptr.")
}
tree->treeStructure().emplace_back(tree->size(), splitValue, assignMissingToLeft, nodeValue,
splitFeature, leftChild, rightChild, gain);
tree->treeStructure().emplace_back(tree->size(), splitValue, assignMissingToLeft,
nodeValue, splitFeature, numberSamples,
leftChild, rightChild, gain);
}

CBoostedTreeInferenceModelBuilder::CBoostedTreeInferenceModelBuilder(TStrVec fieldNames,
10 changes: 8 additions & 2 deletions lib/api/CInferenceModelDefinition.cc
@@ -34,6 +34,7 @@ const std::string JSON_LEFT_CHILD_TAG{"left_child"};
const std::string JSON_LOGISTIC_REGRESSION_TAG{"logistic_regression"};
const std::string JSON_LT{"lt"};
const std::string JSON_NODE_INDEX_TAG{"node_index"};
const std::string JSON_NUMBER_SAMPLES_TAG{"number_samples"};
const std::string JSON_ONE_HOT_ENCODING_TAG{"one_hot_encoding"};
const std::string JSON_PREPROCESSORS_TAG{"preprocessors"};
const std::string JSON_RIGHT_CHILD_TAG{"right_child"};
@@ -79,6 +80,9 @@ void addJsonArray(const std::string& tag,
void CTree::CTreeNode::addToDocument(rapidjson::Value& parentObject,
TRapidJsonWriter& writer) const {
writer.addMember(JSON_NODE_INDEX_TAG, rapidjson::Value(m_NodeIndex).Move(), parentObject);
writer.addMember(
JSON_NUMBER_SAMPLES_TAG,
rapidjson::Value(static_cast<std::uint64_t>(m_NumberSamples)).Move(), parentObject);

if (m_LeftChild) {
// internal node
@@ -118,11 +122,13 @@ CTree::CTreeNode::CTreeNode(TNodeIndex nodeIndex,
bool defaultLeft,
double leafValue,
std::size_t splitFeature,
std::size_t numberSamples,
const TOptionalNodeIndex& leftChild,
const TOptionalNodeIndex& rightChild,
const TOptionalDouble& splitGain)
: m_DefaultLeft(defaultLeft), m_NodeIndex(nodeIndex), m_LeftChild(leftChild),
m_RightChild(rightChild), m_SplitFeature(splitFeature),
: m_DefaultLeft(defaultLeft), m_NodeIndex(nodeIndex),
m_LeftChild(leftChild), m_RightChild(rightChild),
m_SplitFeature(splitFeature), m_NumberSamples(numberSamples),
m_Threshold(threshold), m_LeafValue(leafValue), m_SplitGain(splitGain) {
}

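The std::uint64_t cast above matches what plain rapidjson needs to pick an unambiguous rapidjson::Value constructor for a std::size_t count. A minimal sketch with the raw rapidjson API (not the ml-cpp writer wrapper), assuming a node object and its document allocator are already at hand:

#include <cstddef>
#include <cstdint>
#include <rapidjson/document.h>

// Add the number_samples member to an already-built node object.
void addNumberSamples(rapidjson::Value& node,
                      std::size_t numberSamples,
                      rapidjson::Document::AllocatorType& allocator) {
    node.AddMember("number_samples",
                   rapidjson::Value(static_cast<std::uint64_t>(numberSamples)),
                   allocator);
}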
5 changes: 3 additions & 2 deletions lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc
@@ -299,8 +299,9 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoImportance, SFixture) {
// c1 explains 95% of the prediction value, i.e. the difference from the prediction is less than 2%.
BOOST_REQUIRE_CLOSE(c1, prediction, 5.0);
for (const auto& feature : {"c2", "c3", "c4"}) {
BOOST_REQUIRE_SMALL(readShapValue(result, feature), 2.0);
cNoImportanceMean.add(std::fabs(readShapValue(result, feature)));
double c = readShapValue(result, feature);
BOOST_REQUIRE_SMALL(c, 2.0);
cNoImportanceMean.add(std::fabs(c));
}
}
}
@@ -66,6 +66,10 @@
},
"right_child": {
"type": "integer"
},
"number_samples": {
"description": "Number of training samples that were affected by the node.",
"type": "integer"
}
},
"required": [
@@ -75,7 +79,8 @@
"decision_type",
"default_left",
"left_child",
"right_child"
"right_child",
"number_samples"
],
"additionalProperties": false
},
@@ -88,11 +93,16 @@
},
"leaf_value": {
"type": "number"
},
"number_samples": {
"description": "Number of training samples that were affected by the node.",
"type": "integer"
}
},
"required": [
"node_index",
"leaf_value"
"leaf_value",
"number_samples"
],
"additionalProperties": false
},
@@ -234,10 +244,14 @@
"items": {
"type": "number"
}
},
"num_classes": {
"type": "integer"
}
},
"required": [
"weights"
"weights",
"num_classes"
]
}
},
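For illustration only, a leaf object carrying the three fields the updated schema now requires; the concrete values are invented and kept in a C++ raw string:

#include <string>

// Hypothetical leaf node satisfying the updated schema: node_index, leaf_value
// and number_samples are all required.
const std::string exampleLeaf = R"({
    "node_index": 3,
    "leaf_value": -0.125,
    "number_samples": 217
})";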
16 changes: 14 additions & 2 deletions lib/maths/CBoostedTree.cc
@@ -27,6 +27,7 @@ const std::string SPLIT_FEATURE_TAG{"split_feature"};
const std::string ASSIGN_MISSING_TO_LEFT_TAG{"assign_missing_to_left "};
const std::string NODE_VALUE_TAG{"node_value"};
const std::string SPLIT_VALUE_TAG{"split_value"};
const std::string NUMBER_SAMPLES_TAG{"number_samples"};

double LOG_EPSILON{std::log(100.0 * std::numeric_limits<double>::epsilon())};

@@ -370,6 +371,7 @@ void CBoostedTreeNode::acceptPersistInserter(core::CStatePersistInserter& insert
core::CPersistUtils::persist(ASSIGN_MISSING_TO_LEFT_TAG, m_AssignMissingToLeft, inserter);
core::CPersistUtils::persist(NODE_VALUE_TAG, m_NodeValue, inserter);
core::CPersistUtils::persist(SPLIT_VALUE_TAG, m_SplitValue, inserter);
core::CPersistUtils::persist(NUMBER_SAMPLES_TAG, m_NumberSamples, inserter);
}

bool CBoostedTreeNode::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
@@ -388,6 +390,8 @@ bool CBoostedTreeNode::acceptRestoreTraverser(core::CStateRestoreTraverser& trav
core::CPersistUtils::restore(NODE_VALUE_TAG, m_NodeValue, traverser))
RESTORE(SPLIT_VALUE_TAG,
core::CPersistUtils::restore(SPLIT_VALUE_TAG, m_SplitValue, traverser))
RESTORE(NUMBER_SAMPLES_TAG,
core::CPersistUtils::restore(NUMBER_SAMPLES_TAG, m_NumberSamples, traverser))
} while (traverser.next());
return true;
}
@@ -412,8 +416,16 @@ std::ostringstream& CBoostedTreeNode::doPrint(std::string pad,
}

void CBoostedTreeNode::accept(CVisitor& visitor) const {
visitor.addNode(m_SplitFeature, m_SplitValue, m_AssignMissingToLeft,
m_NodeValue, m_Gain, m_LeftChild, m_RightChild);
visitor.addNode(m_SplitFeature, m_SplitValue, m_AssignMissingToLeft, m_NodeValue,
m_Gain, m_NumberSamples, m_LeftChild, m_RightChild);
}

void CBoostedTreeNode::numberSamples(std::size_t numberSamples) {
m_NumberSamples = numberSamples;
}

std::size_t CBoostedTreeNode::numberSamples() const {
return m_NumberSamples;
}

CBoostedTree::CBoostedTree(core::CDataFrame& frame,
66 changes: 53 additions & 13 deletions lib/maths/CBoostedTreeImpl.cc
@@ -217,6 +217,7 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
this->restoreBestHyperparameters();
std::tie(m_BestForest, std::ignore) = this->trainForest(
frame, allTrainingRowsMask, allTrainingRowsMask, m_TrainingProgress);

m_Instrumentation->nextStep(static_cast<std::uint32_t>(m_CurrentRound));
this->recordState(recordTrainStateCallback);

@@ -233,12 +234,56 @@

this->computeProbabilityAtWhichToAssignClassOne(frame);

// populate numberSamples field in the final forest
this->computeNumberSamples(frame);

// Force progress to one because we can have early exit from loop skip altogether.
m_Instrumentation->updateProgress(1.0);
m_Instrumentation->updateMemoryUsage(
static_cast<std::int64_t>(this->memoryUsage()) - lastMemoryUsage);
}

void CBoostedTreeImpl::computeNumberSamples(const core::CDataFrame& frame) {
for (auto& tree : m_BestForest) {
if (tree.size() == 1) {
root(tree).numberSamples(frame.numberRows());
} else {
auto result = frame.readRows(
m_NumberThreads,
core::bindRetrievableState(
[&](TSizeVec& samplesPerNode, const TRowItr& beginRows, const TRowItr& endRows) {
for (auto row = beginRows; row != endRows; ++row) {
auto encodedRow{m_Encoder->encode(*row)};
const CBoostedTreeNode* node{&root(tree)};
samplesPerNode[0] += 1;
std::size_t nextIndex;
while (node->isLeaf() == false) {
if (node->assignToLeft(encodedRow)) {
nextIndex = node->leftChildIndex();
} else {
nextIndex = node->rightChildIndex();
}
samplesPerNode[nextIndex] += 1;
node = &(tree[nextIndex]);
}
}
},
TSizeVec(tree.size())));
auto& state = result.first;
TSizeVec totalSamplesPerNode{std::move(state[0].s_FunctionState)};
for (std::size_t i = 1; i < state.size(); ++i) {
for (std::size_t nodeIndex = 0;
nodeIndex < totalSamplesPerNode.size(); ++nodeIndex) {
totalSamplesPerNode[nodeIndex] += state[i].s_FunctionState[nodeIndex];
}
}
for (std::size_t i = 0; i < tree.size(); ++i) {
tree[i].numberSamples(totalSamplesPerNode[i]);
}
}
}
}

void CBoostedTreeImpl::recordState(const TTrainingStateCallback& recordTrainState) const {
recordTrainState([this](core::CStatePersistInserter& inserter) {
this->acceptPersistInserter(inserter);
@@ -987,6 +1032,10 @@ const CBoostedTreeNode& CBoostedTreeImpl::root(const TNodeVec& tree) {
return tree[0];
}

CBoostedTreeNode& CBoostedTreeImpl::root(TNodeVec& tree) {
return tree[0];
}

double CBoostedTreeImpl::predictRow(const CEncodedDataFrameRowRef& row,
const TNodeVecVec& forest) {
double result{0.0};
@@ -1138,9 +1187,8 @@ std::size_t CBoostedTreeImpl::maximumTreeSize(std::size_t numberRows) const {
}

namespace {
const std::string VERSION_7_5_TAG{"7.5"};
const std::string VERSION_7_6_TAG{"7.6"};
const TStrVec SUPPORTED_VERSIONS{VERSION_7_5_TAG, VERSION_7_6_TAG};
const std::string VERSION_7_7_TAG{"7.7"};
const TStrVec SUPPORTED_VERSIONS{VERSION_7_7_TAG};

const std::string BAYESIAN_OPTIMIZATION_TAG{"bayesian_optimization"};
const std::string BEST_FOREST_TAG{"best_forest"};
@@ -1204,7 +1252,7 @@ CBoostedTreeImpl::TStrVec CBoostedTreeImpl::bestHyperparameterNames() {
}

void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
core::CPersistUtils::persist(VERSION_7_6_TAG, "", inserter);
core::CPersistUtils::persist(VERSION_7_7_TAG, "", inserter);
core::CPersistUtils::persist(BAYESIAN_OPTIMIZATION_TAG, *m_BayesianOptimization, inserter);
core::CPersistUtils::persist(BEST_FOREST_TEST_LOSS_TAG, m_BestForestTestLoss, inserter);
core::CPersistUtils::persist(CURRENT_ROUND_TAG, m_CurrentRound, inserter);
@@ -1256,15 +1304,7 @@ void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& insert
}

bool CBoostedTreeImpl::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
if (traverser.name() == VERSION_7_5_TAG) {
// Force downsample factor to 1.0.
m_DownsampleFactorOverride = 1.0;
m_DownsampleFactor = 1.0;
m_BestHyperparameters.downsampleFactor(1.0);
// We can't stop cross-validation early because we haven't gathered the
// per fold test losses.
m_StopCrossValidationEarly = false;
} else if (traverser.name() != VERSION_7_6_TAG) {
if (traverser.name() != VERSION_7_7_TAG) {
LOG_ERROR(<< "Input error: unsupported state serialization version. "
<< "Currently supported versions: "
<< core::CContainerPrinter::print(SUPPORTED_VERSIONS) << ".");
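Stripped of the data-frame threading and encoding machinery, computeNumberSamples amounts to routing every training row from the root to a leaf and incrementing a counter at each node visited. A single-threaded sketch over illustrative types (not the ml-cpp API):

#include <cstddef>
#include <vector>

// Minimal illustrative split node; the real code routes encoded rows through
// CBoostedTreeNode::assignToLeft and handles missing values.
struct SketchSplit {
    bool leaf = false;
    std::size_t feature = 0;
    double threshold = 0.0;
    std::size_t leftChild = 0;
    std::size_t rightChild = 0;
};

// Count, for every node in the tree, how many rows pass through it.
std::vector<std::size_t> samplesPerNode(const std::vector<SketchSplit>& tree,
                                        const std::vector<std::vector<double>>& rows) {
    std::vector<std::size_t> counts(tree.size(), 0);
    for (const auto& row : rows) {
        std::size_t index = 0; // every row is counted at the root
        counts[index] += 1;
        while (tree[index].leaf == false) {
            const SketchSplit& node = tree[index];
            index = row[node.feature] < node.threshold ? node.leftChild : node.rightChild;
            counts[index] += 1;
        }
    }
    return counts;
}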