[7.7][ML] Add information about samples per node to the tree #1006

Merged 1 commit on Feb 18, 2020
2 changes: 2 additions & 0 deletions docs/CHANGELOG.asciidoc
@@ -39,6 +39,8 @@ progress, memory usage, etc. (See {ml-pull}906[#906].)

* Improve initialization of learn rate for better and more stable results in regression
and classification. (See {ml-pull}948[#948].)
* Add number of processed training samples to the definition of decision tree nodes.
(See {ml-pull}991[#991].)
* Add new model_size_stats fields to instrument categorization. (See {ml-pull}948[#948]
and {pull}51879[#51879], issue: {issue}50794[#50749].)

1 change: 1 addition & 0 deletions include/api/CBoostedTreeInferenceModelBuilder.h
@@ -38,6 +38,7 @@ class API_EXPORT CBoostedTreeInferenceModelBuilder : public maths::CBoostedTree:
bool assignMissingToLeft,
double nodeValue,
double gain,
std::size_t numberSamples,
maths::CBoostedTreeNode::TOptionalNodeIndex leftChild,
maths::CBoostedTreeNode::TOptionalNodeIndex rightChild) override;
void addIdentityEncoding(std::size_t inputColumnIndex) override;
4 changes: 3 additions & 1 deletion include/api/CInferenceModelDefinition.h
@@ -158,7 +158,8 @@ class API_EXPORT CTree final : public CTrainedModel {
double threshold,
bool defaultLeft,
double leafValue,
size_t splitFeature,
std::size_t splitFeature,
std::size_t numberSamples,
const TOptionalNodeIndex& leftChild,
const TOptionalNodeIndex& rightChild,
const TOptionalDouble& splitGain);
@@ -175,6 +176,7 @@ class API_EXPORT CTree final : public CTrainedModel {
TOptionalNodeIndex m_LeftChild;
TOptionalNodeIndex m_RightChild;
std::size_t m_SplitFeature;
std::size_t m_NumberSamples;
double m_Threshold;
double m_LeafValue;
TOptionalDouble m_SplitGain;
8 changes: 8 additions & 0 deletions include/maths/CBoostedTree.h
@@ -265,6 +265,7 @@ class MATHS_EXPORT CBoostedTreeNode final {
bool assignMissingToLeft,
double nodeValue,
double gain,
std::size_t numberSamples,
TOptionalNodeIndex leftChild,
TOptionalNodeIndex rightChild) = 0;
};
@@ -306,6 +307,12 @@ class MATHS_EXPORT CBoostedTreeNode final {
//! Get the total curvature at the rows below this node.
double curvature() const { return m_Curvature; }

//! Set the number of samples to \p value.
void numberSamples(std::size_t value);

//! Get the number of samples affected by the node.
std::size_t numberSamples() const;

//! Get the index of the left child node.
TNodeIndex leftChildIndex() const { return m_LeftChild.get(); }

@@ -348,6 +355,7 @@ class MATHS_EXPORT CBoostedTreeNode final {
double m_NodeValue = 0.0;
double m_Gain = 0.0;
double m_Curvature = 0.0;
std::size_t m_NumberSamples = 0;
};

//! \brief A boosted regression tree model.
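The new numberSamples accessors record how many training rows pass through each node. Below is a minimal sketch of that bookkeeping, using a simplified stand-in node type rather than CBoostedTreeNode itself (the SToyNode struct and its field names are illustrative assumptions only): every row increments the root's count and then the count of whichever child it is routed to, so a parent's count is always the sum of its children's.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative stand-in for a tree node; CBoostedTreeNode carries far more state.
struct SToyNode {
    bool s_IsLeaf;
    std::size_t s_SplitFeature;
    double s_SplitValue;
    std::size_t s_LeftChild;
    std::size_t s_RightChild;
    std::size_t s_NumberSamples; // plays the role of m_NumberSamples
};

int main() {
    // A stump: the root splits feature 0 at 0.5 and both children are leaves.
    std::vector<SToyNode> tree{{false, 0, 0.5, 1, 2, 0}, // root
                               {true, 0, 0.0, 0, 0, 0},  // left leaf
                               {true, 0, 0.0, 0, 0, 0}}; // right leaf

    std::vector<std::vector<double>> rows{{0.1}, {0.4}, {0.7}, {0.9}, {0.2}};
    for (const auto& row : rows) {
        std::size_t node{0};
        ++tree[node].s_NumberSamples;
        while (tree[node].s_IsLeaf == false) {
            node = row[tree[node].s_SplitFeature] < tree[node].s_SplitValue
                       ? tree[node].s_LeftChild
                       : tree[node].s_RightChild;
            ++tree[node].s_NumberSamples;
        }
    }

    // The root sees every training row and its children partition that count.
    std::cout << tree[0].s_NumberSamples << " = " << tree[1].s_NumberSamples
              << " + " << tree[2].s_NumberSamples << std::endl; // prints "5 = 3 + 2"
    return 0;
}
```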
6 changes: 6 additions & 0 deletions include/maths/CBoostedTreeImpl.h
@@ -252,6 +252,9 @@ class MATHS_EXPORT CBoostedTreeImpl final {
//! Get the root node of \p tree.
static const CBoostedTreeNode& root(const TNodeVec& tree);

//! Get the root node of \p tree.
static CBoostedTreeNode& root(TNodeVec& tree);

//! Get the forest's prediction for \p row.
static double predictRow(const CEncodedDataFrameRowRef& row, const TNodeVecVec& forest);

@@ -287,6 +290,9 @@ class MATHS_EXPORT CBoostedTreeImpl final {
//! Record the training state using the \p recordTrainState callback function
void recordState(const TTrainingStateCallback& recordTrainState) const;

//! Populate the numberSamples field in m_BestForest.
void computeNumberSamples(const core::CDataFrame& frame);

private:
mutable CPRNG::CXorOShiro128Plus m_Rng;
std::size_t m_NumberThreads;
13 changes: 1 addition & 12 deletions include/maths/CTreeShapFeatureImportance.h
@@ -40,18 +40,9 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
//! by \p offset.
void shap(core::CDataFrame& frame, const CDataFrameCategoryEncoder& encoder, std::size_t offset);

//! Compute number of training samples from \p frame that pass every node in the \p tree.
static TDoubleVec samplesPerNode(const TTree& tree,
const core::CDataFrame& frame,
const CDataFrameCategoryEncoder& encoder,
std::size_t numThreads);

//! Recursively computes inner node values as the weighted average of the children's (leaf) values.
//! \returns The maximum depth of the tree.
static std::size_t updateNodeValues(TTree& tree,
std::size_t nodeIndex,
const TDoubleVec& samplesPerNode,
std::size_t depth);
static size_t updateNodeValues(TTree& tree, std::size_t nodeIndex, std::size_t depth);

//! Get the reference to the trees.
TTreeVec& trees() { return m_Trees; }
@@ -126,7 +117,6 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
//! Recursively traverses all paths in the \p tree and updates SHAP values once it hits a leaf.
//! Ref. Algorithm 2 in the paper by Lundberg et al.
void shapRecursive(const TTree& tree,
const TDoubleVec& samplesPerNode,
const CDataFrameCategoryEncoder& encoder,
const CEncodedDataFrameRowRef& encodedRow,
SPath& splitPath,
Expand All @@ -146,7 +136,6 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
private:
TTreeVec m_Trees;
std::size_t m_NumberThreads;
TDoubleVecVec m_SamplesPerNode;
};
}
}
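With the counts now stored on the nodes themselves, the samplesPerNode helper and the m_SamplesPerNode cache above become redundant. The removed doc comment describes updateNodeValues as filling each inner node's value with the weighted average of its children's values, presumably weighted by the number of samples reaching each child. A minimal sketch of that recursion over a simplified node type follows; SToyNode and its fields are assumptions for illustration, not the library's types, and the sketch assumes every node was reached by at least one training sample.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative stand-in for a tree node.
struct SToyNode {
    bool s_IsLeaf;
    double s_NodeValue;          // leaf value, or the value this pass fills in
    std::size_t s_NumberSamples; // now carried by the node itself
    std::size_t s_LeftChild;
    std::size_t s_RightChild;
};

// Fill each inner node's value with the sample-weighted average of its
// children's values and return the maximum depth of the tree.
std::size_t updateNodeValues(std::vector<SToyNode>& tree, std::size_t nodeIndex, std::size_t depth) {
    SToyNode& node{tree[nodeIndex]};
    if (node.s_IsLeaf) {
        return depth;
    }
    std::size_t depthLeft{updateNodeValues(tree, node.s_LeftChild, depth + 1)};
    std::size_t depthRight{updateNodeValues(tree, node.s_RightChild, depth + 1)};
    const SToyNode& left{tree[node.s_LeftChild]};
    const SToyNode& right{tree[node.s_RightChild]};
    double total{static_cast<double>(left.s_NumberSamples + right.s_NumberSamples)};
    node.s_NodeValue = (static_cast<double>(left.s_NumberSamples) * left.s_NodeValue +
                        static_cast<double>(right.s_NumberSamples) * right.s_NodeValue) /
                       total;
    return std::max(depthLeft, depthRight);
}
```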
6 changes: 4 additions & 2 deletions lib/api/CBoostedTreeInferenceModelBuilder.cc
@@ -89,6 +89,7 @@ void CBoostedTreeInferenceModelBuilder::addNode(std::size_t splitFeature,
bool assignMissingToLeft,
double nodeValue,
double gain,
std::size_t numberSamples,
maths::CBoostedTreeNode::TOptionalNodeIndex leftChild,
maths::CBoostedTreeNode::TOptionalNodeIndex rightChild) {
auto ensemble{static_cast<CEnsemble*>(m_Definition.trainedModel().get())};
@@ -97,8 +98,9 @@ void CBoostedTreeInferenceModelBuilder::addNode(std::size_t splitFeature,
if (tree == nullptr) {
HANDLE_FATAL(<< "Internal error. Tree points to a nullptr.")
}
tree->treeStructure().emplace_back(tree->size(), splitValue, assignMissingToLeft, nodeValue,
splitFeature, leftChild, rightChild, gain);
tree->treeStructure().emplace_back(tree->size(), splitValue, assignMissingToLeft,
nodeValue, splitFeature, numberSamples,
leftChild, rightChild, gain);
}

CBoostedTreeInferenceModelBuilder::CBoostedTreeInferenceModelBuilder(TStrVec fieldNames,
10 changes: 8 additions & 2 deletions lib/api/CInferenceModelDefinition.cc
@@ -34,6 +34,7 @@ const std::string JSON_LEFT_CHILD_TAG{"left_child"};
const std::string JSON_LOGISTIC_REGRESSION_TAG{"logistic_regression"};
const std::string JSON_LT{"lt"};
const std::string JSON_NODE_INDEX_TAG{"node_index"};
const std::string JSON_NUMBER_SAMPLES_TAG{"number_samples"};
const std::string JSON_ONE_HOT_ENCODING_TAG{"one_hot_encoding"};
const std::string JSON_PREPROCESSORS_TAG{"preprocessors"};
const std::string JSON_RIGHT_CHILD_TAG{"right_child"};
@@ -79,6 +80,9 @@ void addJsonArray(const std::string& tag,
void CTree::CTreeNode::addToDocument(rapidjson::Value& parentObject,
TRapidJsonWriter& writer) const {
writer.addMember(JSON_NODE_INDEX_TAG, rapidjson::Value(m_NodeIndex).Move(), parentObject);
writer.addMember(
JSON_NUMBER_SAMPLES_TAG,
rapidjson::Value(static_cast<std::uint64_t>(m_NumberSamples)).Move(), parentObject);

if (m_LeftChild) {
// internal node
@@ -118,11 +122,13 @@ CTree::CTreeNode::CTreeNode(TNodeIndex nodeIndex,
bool defaultLeft,
double leafValue,
std::size_t splitFeature,
std::size_t numberSamples,
const TOptionalNodeIndex& leftChild,
const TOptionalNodeIndex& rightChild,
const TOptionalDouble& splitGain)
: m_DefaultLeft(defaultLeft), m_NodeIndex(nodeIndex), m_LeftChild(leftChild),
m_RightChild(rightChild), m_SplitFeature(splitFeature),
: m_DefaultLeft(defaultLeft), m_NodeIndex(nodeIndex),
m_LeftChild(leftChild), m_RightChild(rightChild),
m_SplitFeature(splitFeature), m_NumberSamples(numberSamples),
m_Threshold(threshold), m_LeafValue(leafValue), m_SplitGain(splitGain) {
}

5 changes: 3 additions & 2 deletions lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc
@@ -299,8 +299,9 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoImportance, SFixture) {
// c1 explains 95% of the prediction value, i.e. the difference from the prediction is less than 2%.
BOOST_REQUIRE_CLOSE(c1, prediction, 5.0);
for (const auto& feature : {"c2", "c3", "c4"}) {
BOOST_REQUIRE_SMALL(readShapValue(result, feature), 2.0);
cNoImportanceMean.add(std::fabs(readShapValue(result, feature)));
double c = readShapValue(result, feature);
BOOST_REQUIRE_SMALL(c, 2.0);
cNoImportanceMean.add(std::fabs(c));
}
}
}
@@ -66,6 +66,10 @@
},
"right_child": {
"type": "integer"
},
"number_samples": {
"description": "Number of training samples that were affected by the node.",
"type": "integer"
}
},
"required": [
@@ -75,7 +79,8 @@
"decision_type",
"default_left",
"left_child",
"right_child"
"right_child",
"number_samples"
],
"additionalProperties": false
},
@@ -88,11 +93,16 @@
},
"leaf_value": {
"type": "number"
},
"number_samples": {
"description": "Number of training samples that were affected by the node.",
"type": "integer"
}
},
"required": [
"node_index",
"leaf_value"
"leaf_value",
"number_samples"
],
"additionalProperties": false
},
@@ -234,10 +244,14 @@
"items": {
"type": "number"
}
},
"num_classes": {
"type": "integer"
}
},
"required": [
"weights"
"weights",
"num_classes"
]
}
},
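Both node variants (split and leaf) now require number_samples, and num_classes is now required wherever weights is. Below is a small sketch of how a consumer might read the new field from a serialized leaf node with RapidJSON, the same JSON library the C++ code above uses; the example document and its values are made up for illustration.

```cpp
#include <rapidjson/document.h>

#include <cstdint>
#include <iostream>

int main() {
    // Hypothetical leaf node as it could appear in the inference model definition.
    const char* json{R"({"node_index": 3, "leaf_value": 0.42, "number_samples": 137})"};

    rapidjson::Document document;
    document.Parse(json);
    if (document.HasParseError()) {
        return 1;
    }

    // "number_samples" is a required field on both split and leaf nodes.
    if (document.HasMember("number_samples") && document["number_samples"].IsUint64()) {
        std::uint64_t numberSamples{document["number_samples"].GetUint64()};
        std::cout << "number_samples = " << numberSamples << std::endl; // prints 137
    }
    return 0;
}
```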
16 changes: 14 additions & 2 deletions lib/maths/CBoostedTree.cc
@@ -27,6 +27,7 @@ const std::string SPLIT_FEATURE_TAG{"split_feature"};
const std::string ASSIGN_MISSING_TO_LEFT_TAG{"assign_missing_to_left "};
const std::string NODE_VALUE_TAG{"node_value"};
const std::string SPLIT_VALUE_TAG{"split_value"};
const std::string NUMBER_SAMPLES_TAG{"number_samples"};

double LOG_EPSILON{std::log(100.0 * std::numeric_limits<double>::epsilon())};

@@ -370,6 +371,7 @@ void CBoostedTreeNode::acceptPersistInserter(core::CStatePersistInserter& insert
core::CPersistUtils::persist(ASSIGN_MISSING_TO_LEFT_TAG, m_AssignMissingToLeft, inserter);
core::CPersistUtils::persist(NODE_VALUE_TAG, m_NodeValue, inserter);
core::CPersistUtils::persist(SPLIT_VALUE_TAG, m_SplitValue, inserter);
core::CPersistUtils::persist(NUMBER_SAMPLES_TAG, m_NumberSamples, inserter);
}

bool CBoostedTreeNode::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
@@ -388,6 +390,8 @@ bool CBoostedTreeNode::acceptRestoreTraverser(core::CStateRestoreTraverser& trav
core::CPersistUtils::restore(NODE_VALUE_TAG, m_NodeValue, traverser))
RESTORE(SPLIT_VALUE_TAG,
core::CPersistUtils::restore(SPLIT_VALUE_TAG, m_SplitValue, traverser))
RESTORE(NUMBER_SAMPLES_TAG,
core::CPersistUtils::restore(NUMBER_SAMPLES_TAG, m_NumberSamples, traverser))
} while (traverser.next());
return true;
}
@@ -412,8 +416,16 @@ std::ostringstream& CBoostedTreeNode::doPrint(std::string pad,
}

void CBoostedTreeNode::accept(CVisitor& visitor) const {
visitor.addNode(m_SplitFeature, m_SplitValue, m_AssignMissingToLeft,
m_NodeValue, m_Gain, m_LeftChild, m_RightChild);
visitor.addNode(m_SplitFeature, m_SplitValue, m_AssignMissingToLeft, m_NodeValue,
m_Gain, m_NumberSamples, m_LeftChild, m_RightChild);
}

void CBoostedTreeNode::numberSamples(std::size_t numberSamples) {
m_NumberSamples = numberSamples;
}

std::size_t CBoostedTreeNode::numberSamples() const {
return m_NumberSamples;
}

CBoostedTree::CBoostedTree(core::CDataFrame& frame,
66 changes: 53 additions & 13 deletions lib/maths/CBoostedTreeImpl.cc
@@ -217,6 +217,7 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
this->restoreBestHyperparameters();
std::tie(m_BestForest, std::ignore) = this->trainForest(
frame, allTrainingRowsMask, allTrainingRowsMask, m_TrainingProgress);

m_Instrumentation->nextStep(static_cast<std::uint32_t>(m_CurrentRound));
this->recordState(recordTrainStateCallback);

@@ -233,12 +234,56 @@

this->computeProbabilityAtWhichToAssignClassOne(frame);

// Populate the numberSamples field in the final forest.
this->computeNumberSamples(frame);

// Force progress to one because we can exit the loop early or skip it altogether.
m_Instrumentation->updateProgress(1.0);
m_Instrumentation->updateMemoryUsage(
static_cast<std::int64_t>(this->memoryUsage()) - lastMemoryUsage);
}

void CBoostedTreeImpl::computeNumberSamples(const core::CDataFrame& frame) {
for (auto& tree : m_BestForest) {
if (tree.size() == 1) {
root(tree).numberSamples(frame.numberRows());
} else {
auto result = frame.readRows(
m_NumberThreads,
core::bindRetrievableState(
[&](TSizeVec& samplesPerNode, const TRowItr& beginRows, const TRowItr& endRows) {
for (auto row = beginRows; row != endRows; ++row) {
auto encodedRow{m_Encoder->encode(*row)};
const CBoostedTreeNode* node{&root(tree)};
samplesPerNode[0] += 1;
std::size_t nextIndex;
while (node->isLeaf() == false) {
if (node->assignToLeft(encodedRow)) {
nextIndex = node->leftChildIndex();
} else {
nextIndex = node->rightChildIndex();
}
samplesPerNode[nextIndex] += 1;
node = &(tree[nextIndex]);
}
}
},
TSizeVec(tree.size())));
auto& state = result.first;
TSizeVec totalSamplesPerNode{std::move(state[0].s_FunctionState)};
for (std::size_t i = 1; i < state.size(); ++i) {
for (std::size_t nodeIndex = 0;
nodeIndex < totalSamplesPerNode.size(); ++nodeIndex) {
totalSamplesPerNode[nodeIndex] += state[i].s_FunctionState[nodeIndex];
}
}
for (std::size_t i = 0; i < tree.size(); ++i) {
tree[i].numberSamples(totalSamplesPerNode[i]);
}
}
}
}

void CBoostedTreeImpl::recordState(const TTrainingStateCallback& recordTrainState) const {
recordTrainState([this](core::CStatePersistInserter& inserter) {
this->acceptPersistInserter(inserter);
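computeNumberSamples above walks every training row down each tree and tallies, per node, how many rows pass through it; core::bindRetrievableState gives each reader thread its own TSizeVec, and the per-thread vectors are summed into totalSamplesPerNode afterwards. The following framework-free sketch shows that accumulate-then-reduce shape with plain std::thread; the function, its parameters, and the routeRow callback are illustrative assumptions, not the library's API, and at least one thread is assumed.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Each worker owns a local tally; the locals are summed once all workers join.
// "routeRow" stands in for walking an encoded row down the tree and bumping the
// count of every node it visits, as computeNumberSamples does.
std::vector<std::size_t>
countSamplesPerNode(const std::vector<std::vector<double>>& rows,
                    std::size_t numberNodes,
                    std::size_t numberThreads,
                    const std::function<void(const std::vector<double>&, std::vector<std::size_t>&)>& routeRow) {
    std::vector<std::vector<std::size_t>> locals(
        numberThreads, std::vector<std::size_t>(numberNodes, 0));
    std::vector<std::thread> workers;
    std::size_t chunk{(rows.size() + numberThreads - 1) / numberThreads};

    for (std::size_t t = 0; t < numberThreads; ++t) {
        workers.emplace_back([&rows, &locals, &routeRow, chunk, t] {
            std::size_t begin{t * chunk};
            std::size_t end{std::min(rows.size(), begin + chunk)};
            for (std::size_t i = begin; i < end; ++i) {
                routeRow(rows[i], locals[t]);
            }
        });
    }
    for (auto& worker : workers) {
        worker.join();
    }

    // Reduce: sum the per-thread tallies node by node.
    std::vector<std::size_t> totals(numberNodes, 0);
    for (const auto& local : locals) {
        for (std::size_t node = 0; node < numberNodes; ++node) {
            totals[node] += local[node];
        }
    }
    return totals;
}
```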
@@ -987,6 +1032,10 @@ const CBoostedTreeNode& CBoostedTreeImpl::root(const TNodeVec& tree) {
return tree[0];
}

CBoostedTreeNode& CBoostedTreeImpl::root(TNodeVec& tree) {
return tree[0];
}

double CBoostedTreeImpl::predictRow(const CEncodedDataFrameRowRef& row,
const TNodeVecVec& forest) {
double result{0.0};
@@ -1138,9 +1187,8 @@ std::size_t CBoostedTreeImpl::maximumTreeSize(std::size_t numberRows) const {
}

namespace {
const std::string VERSION_7_5_TAG{"7.5"};
const std::string VERSION_7_6_TAG{"7.6"};
const TStrVec SUPPORTED_VERSIONS{VERSION_7_5_TAG, VERSION_7_6_TAG};
const std::string VERSION_7_7_TAG{"7.7"};
const TStrVec SUPPORTED_VERSIONS{VERSION_7_7_TAG};

const std::string BAYESIAN_OPTIMIZATION_TAG{"bayesian_optimization"};
const std::string BEST_FOREST_TAG{"best_forest"};
@@ -1204,7 +1252,7 @@ CBoostedTreeImpl::TStrVec CBoostedTreeImpl::bestHyperparameterNames() {
}

void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
core::CPersistUtils::persist(VERSION_7_6_TAG, "", inserter);
core::CPersistUtils::persist(VERSION_7_7_TAG, "", inserter);
core::CPersistUtils::persist(BAYESIAN_OPTIMIZATION_TAG, *m_BayesianOptimization, inserter);
core::CPersistUtils::persist(BEST_FOREST_TEST_LOSS_TAG, m_BestForestTestLoss, inserter);
core::CPersistUtils::persist(CURRENT_ROUND_TAG, m_CurrentRound, inserter);
@@ -1256,15 +1304,7 @@ void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& insert
}

bool CBoostedTreeImpl::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
if (traverser.name() == VERSION_7_5_TAG) {
// Force downsample factor to 1.0.
m_DownsampleFactorOverride = 1.0;
m_DownsampleFactor = 1.0;
m_BestHyperparameters.downsampleFactor(1.0);
// We can't stop cross-validation early because we haven't gathered the
// per fold test losses.
m_StopCrossValidationEarly = false;
} else if (traverser.name() != VERSION_7_6_TAG) {
if (traverser.name() != VERSION_7_7_TAG) {
LOG_ERROR(<< "Input error: unsupported state serialization version. "
<< "Currently supported versions: "
<< core::CContainerPrinter::print(SUPPORTED_VERSIONS) << ".");