Add feature to print forest shape in FIL upon importing #3763

Merged · 30 commits · May 24, 2021
Commits (30)
1245d00
try 1
levsnv Apr 17, 2021
c81eb32
fixes
levsnv Apr 18, 2021
e8433c1
fixed
levsnv Apr 18, 2021
2759de3
Apply suggestions from code review
levsnv Apr 27, 2021
e21f0e9
addressed review comments
levsnv Apr 27, 2021
3f4a5d6
refactor
levsnv Apr 28, 2021
9e5c5a4
style
levsnv Apr 28, 2021
8ae9e5f
style
levsnv Apr 28, 2021
fe12e6c
Merge remote-tracking branch 'origin/refactor-cython-kwargs' into pri…
levsnv Apr 28, 2021
5f85f27
Merge branch 'branch-0.20' into print-model-shape
levsnv Apr 28, 2021
31f9f46
fix memory leaks; confusing naming
levsnv Apr 28, 2021
22abfd4
stop sprawling code duplication
levsnv Apr 28, 2021
456ca96
Merge branch 'print-model-shape' of github.com:levsnv/cuml into print…
levsnv Apr 28, 2021
6fa42ca
Merge remote-tracking branch 'rapidsai/branch-0.20' into print-model-…
levsnv Apr 29, 2021
b07e8d7
Merge branch 'branch-0.20' of github.com:rapidsai/cuml into print-mod…
levsnv Apr 30, 2021
b42a3d5
Apply suggestions from code review
levsnv May 6, 2021
d1d6917
addressed review comments except one
levsnv May 11, 2021
e719e7b
fixed FNV to standard byte-based, added tests
levsnv May 11, 2021
93d0c31
style
levsnv May 11, 2021
64423b8
Merge branch 'branch-0.20' of github.com:rapidsai/cuml into print-mod…
levsnv May 11, 2021
a9f7ac9
fixed copyright issues, excluded trivial implementation of public code
levsnv May 11, 2021
d3c24f9
style, copyright
levsnv May 19, 2021
f5f825f
addressed review comments
levsnv May 19, 2021
47d5911
added back " MB" suffix
levsnv May 19, 2021
6b41953
fixed insufficient precision
levsnv May 19, 2021
739940b
fix extra <<
levsnv May 20, 2021
7e6302f
switched from forest_shape_file to compute_shape_str and shape_str
levsnv May 21, 2021
dcce65a
copyright.year
levsnv May 21, 2021
2a07441
style, documentation
levsnv May 21, 2021
0d2a420
more verbose error message
levsnv May 21, 2021
4 changes: 3 additions & 1 deletion cpp/bench/sg/fil.cu
@@ -90,7 +90,9 @@ class FIL : public RegressionFixture<float> {
.algo = p_rest.algo,
.output_class = params.nclasses > 1, // cuML RF forest
.threshold = 1.f / params.nclasses, //Fixture::DatasetParams
.storage_type = p_rest.storage};
.storage_type = p_rest.storage,
.blocks_per_sm = 0,
Contributor:
What is .blocks_per_sm doing here?

Contributor Author (levsnv):
In short, the compiler was complaining about the designated initializers, which are a gcc extension before C++20: https://gcc.gnu.org/onlinedocs/gcc/Designated-Inits.html
I had forgotten to add the .blocks_per_sm=0 in the relevant PR, but it must have gotten default-initialized to 0 and worked well enough to miss the issue.
Now, with the trailing member added, it may be triggering the gcc limitation on out-of-order designated initializers. If I also omit .pforest_shape_str=nullptr, I am setting us up for even more pain down the road, even if it works.
Maybe there's a different workaround, but wouldn't this code be the most desirable one in the end?

.pforest_shape_str = nullptr};
ML::fil::from_treelite(*handle, &forest, model, &tl_params);

// only time prediction
4 changes: 4 additions & 0 deletions cpp/include/cuml/fil/fil.h
@@ -65,6 +65,7 @@ enum storage_type_t {
whether a particular forest can be imported as SPARSE8 */
SPARSE8,
};
static const char* storage_type_repr[] = {"AUTO", "DENSE", "SPARSE", "SPARSE8"};

struct forest;

@@ -89,6 +90,9 @@ struct treelite_params_t {
// suggested values (if nonzero) are from 2 to 7
// if zero, launches ceildiv(num_rows, NITEMS) blocks
int blocks_per_sm;
// if non-nullptr, *pforest_shape_str will be set to caller-owned string that
// contains forest shape
char** pforest_shape_str;
};

/** from_treelite uses a treelite model to initialize the forest
41 changes: 41 additions & 0 deletions cpp/include/cuml/fil/fnv_hash.h
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <limits.h>
#include <cstdint>
#include <numeric>

// Implements https://tools.ietf.org/html/draft-eastlake-fnv-17.html
// Algorithm is public domain, non-cryptographic strength and no patents or rights to patent.
// If input elements are not 8-bit, such a computation does not match
// the FNV spec.
template <typename It>
unsigned long long fowler_noll_vo_fingerprint64(It begin, It end) {
static_assert(sizeof(*begin) == 1,
"FNV deals with byte-sized (octet) input arrays only");
return std::accumulate(begin, end, 14695981039346656037ull,
[](const unsigned long long& fingerprint, auto x) {
return (fingerprint * 0x100000001b3ull) ^ x;
});
}

// xor-folded fingerprint64 to ensure first bits are affected by other input bits
// should give a 1% collision probability within a 10'000 hash set
template <typename It>
uint32_t fowler_noll_vo_fingerprint64_32(It begin, It end) {
unsigned long long fp64 = fowler_noll_vo_fingerprint64(begin, end);
return (fp64 & UINT_MAX) ^ (fp64 >> 32);
}
155 changes: 136 additions & 19 deletions cpp/src/fil/fil.cu
@@ -24,11 +24,13 @@
#include <treelite/tree.h>
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <limits>
#include <stack>
#include <utility>

#include <cuml/fil/fil.h>
#include <cuml/fil/fnv_hash.h>
#include <raft/cudart_utils.h>
#include <cuml/common/logger.hpp>
#include <raft/handle.hpp>
@@ -636,6 +638,85 @@ int tree2fil_sparse(std::vector<fil_node_t>& nodes, int root,
return root;
}

struct level_entry {
int n_branch_nodes, n_leaves;
};
typedef std::pair<int, int> pair_t;
// hist has branch and leaf count given depth
template <typename T, typename L>
inline void tree_depth_hist(const tl::Tree<T, L>& tree,
std::vector<level_entry>& hist) {
std::stack<pair_t> stack; // {tl_id, depth}
stack.push({tree_root(tree), 0});
while (!stack.empty()) {
const pair_t& top = stack.top();
int node_id = top.first;
int depth = top.second;
stack.pop();

while (!tree.IsLeaf(node_id)) {
if (depth >= hist.size()) hist.resize(depth + 1, {0, 0});
hist[depth].n_branch_nodes++;
stack.push({tree.LeftChild(node_id), depth + 1});
node_id = tree.RightChild(node_id);
depth++;
}

if (depth >= hist.size()) hist.resize(depth + 1, {0, 0});
hist[depth].n_leaves++;
}
}

template <typename T, typename L>
std::stringstream depth_hist_and_max(const tl::ModelImpl<T, L>& model) {
using namespace std;
vector<level_entry> hist;
for (const auto& tree : model.trees) tree_depth_hist(tree, hist);

int min_leaf_depth = -1, leaves_times_depth = 0, total_branches = 0,
total_leaves = 0;
stringstream forest_shape;
ios default_state(nullptr);
default_state.copyfmt(forest_shape);
forest_shape << "Depth histogram:" << endl
<< "depth branches leaves nodes" << endl;
for (int level = 0; level < hist.size(); ++level) {
level_entry e = hist[level];
forest_shape << setw(5) << level << setw(9) << e.n_branch_nodes << setw(7)
<< e.n_leaves << setw(8) << e.n_branch_nodes + e.n_leaves
<< endl;
forest_shape.copyfmt(default_state);
if (e.n_leaves && min_leaf_depth == -1) min_leaf_depth = level;
leaves_times_depth += e.n_leaves * level;
total_branches += e.n_branch_nodes;
total_leaves += e.n_leaves;
}
int total_nodes = total_branches + total_leaves;
forest_shape << "Total: branches: " << total_branches
<< " leaves: " << total_leaves << " nodes: " << total_nodes
<< endl;
forest_shape << "Avg nodes per tree: " << setprecision(2)
<< total_nodes / (float)hist[0].n_branch_nodes << endl;
forest_shape.copyfmt(default_state);
forest_shape << "Leaf depth: min: " << min_leaf_depth
<< " avg: " << setprecision(2) << fixed
<< leaves_times_depth / (float)total_leaves
<< " max: " << hist.size() - 1 << endl;
forest_shape.copyfmt(default_state);

vector<char> hist_bytes(hist.size() * sizeof(hist[0]));
memcpy(&hist_bytes[0], &hist[0], hist_bytes.size());
// std::hash does not promise to not be identity. Xoring plain numbers which
// add up to one another erases information, hence, std::hash is unsuitable here
forest_shape << "Depth histogram fingerprint: " << hex
<< fowler_noll_vo_fingerprint64_32(hist_bytes.begin(),
hist_bytes.end())
<< endl;
forest_shape.copyfmt(default_state);

return forest_shape;
}

template <typename T, typename L>
size_t tl_leaf_vector_size(const tl::ModelImpl<T, L>& model) {
const tl::Tree<T, L>& tree = model.trees[0];
@@ -729,9 +810,9 @@ void tl2fil_common(forest_params_t* params, const tl::ModelImpl<T, L>& model,

// uses treelite model with additional tl_params to initialize FIL params
// and dense nodes (stored in *pnodes)
template <typename T, typename L>
template <typename threshold_t, typename leaf_t>
void tl2fil_dense(std::vector<dense_node>* pnodes, forest_params_t* params,
const tl::ModelImpl<T, L>& model,
const tl::ModelImpl<threshold_t, leaf_t>& model,
const treelite_params_t* tl_params) {
tl2fil_common(params, model, tl_params);

@@ -746,8 +827,8 @@ void tl2fil_dense(std::vector<dense_node>* pnodes, forest_params_t* params,

template <typename fil_node_t>
struct tl2fil_sparse_check_t {
template <typename T, typename L>
static void check(const tl::ModelImpl<T, L>& model) {
template <typename threshold_t, typename leaf_t>
static void check(const tl::ModelImpl<threshold_t, leaf_t>& model) {
ASSERT(false,
"internal error: "
"only a specialization of this template should be used");
@@ -757,16 +838,16 @@ struct tl2fil_sparse_check_t {
template <>
struct tl2fil_sparse_check_t<sparse_node16> {
// no extra check for 16-byte sparse nodes
template <typename T, typename L>
static void check(const tl::ModelImpl<T, L>& model) {}
template <typename threshold_t, typename leaf_t>
static void check(const tl::ModelImpl<threshold_t, leaf_t>& model) {}
};

template <>
struct tl2fil_sparse_check_t<sparse_node8> {
static const int MAX_FEATURES = 1 << sparse_node8::FID_NUM_BITS;
static const int MAX_TREE_NODES = (1 << sparse_node8::LEFT_NUM_BITS) - 1;
template <typename T, typename L>
static void check(const tl::ModelImpl<T, L>& model) {
template <typename threshold_t, typename leaf_t>
static void check(const tl::ModelImpl<threshold_t, leaf_t>& model) {
// check the number of features
int num_features = model.num_feature;
ASSERT(num_features <= MAX_FEATURES,
@@ -775,7 +856,7 @@ struct tl2fil_sparse_check_t<sparse_node8> {
num_features, MAX_FEATURES);

// check the number of tree nodes
const std::vector<tl::Tree<T, L>>& trees = model.trees;
const std::vector<tl::Tree<threshold_t, leaf_t>>& trees = model.trees;
for (int i = 0; i < trees.size(); ++i) {
int num_nodes = trees[i].num_nodes;
ASSERT(num_nodes <= MAX_TREE_NODES,
@@ -788,9 +869,10 @@

// uses treelite model with additional tl_params to initialize FIL params,
// trees (stored in *ptrees) and sparse nodes (stored in *pnodes)
template <typename fil_node_t, typename T, typename L>
template <typename fil_node_t, typename threshold_t, typename leaf_t>
void tl2fil_sparse(std::vector<int>* ptrees, std::vector<fil_node_t>* pnodes,
forest_params_t* params, const tl::ModelImpl<T, L>& model,
forest_params_t* params,
const tl::ModelImpl<threshold_t, leaf_t>& model,
const treelite_params_t* tl_params) {
tl2fil_common(params, model, tl_params);
tl2fil_sparse_check_t<fil_node_t>::check(model);
@@ -843,18 +925,21 @@ template void init_sparse<sparse_node8>(const raft::handle_t& h, forest_t* pf,
const sparse_node8* nodes,
const forest_params_t* params);

template <typename T, typename L>
template <typename threshold_t, typename leaf_t>
void from_treelite(const raft::handle_t& handle, forest_t* pforest,
const tl::ModelImpl<T, L>& model,
const tl::ModelImpl<threshold_t, leaf_t>& model,
const treelite_params_t* tl_params) {
// Invariants on threshold and leaf types
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
static_assert(std::is_same<threshold_t, float>::value ||
std::is_same<threshold_t, double>::value,
"Model must contain float32 or float64 thresholds for splits");
ASSERT((std::is_same<L, float>::value || std::is_same<L, double>::value),
"Models with integer leaf output are not yet supported");
ASSERT(
(std::is_same<leaf_t, float>::value || std::is_same<leaf_t, double>::value),
"Models with integer leaf output are not yet supported");
// Display appropriate warnings when float64 values are being casted into
// float32, as FIL only supports inferencing with float32 for the time being
if (std::is_same<T, double>::value || std::is_same<L, double>::value) {
if (std::is_same<threshold_t, double>::value ||
std::is_same<leaf_t, double>::value) {
CUML_LOG_WARN(
"Casting all thresholds and leaf values to float32, as FIL currently "
"doesn't support inferencing models with float64 values. "
@@ -889,6 +974,10 @@ void from_treelite(const raft::handle_t& handle, forest_t* pforest,
// sync is necessary as nodes is used in init_dense(),
// but destructed at the end of this function
CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));
if (tl_params->pforest_shape_str) {
*tl_params->pforest_shape_str =
sprintf_shape(model, storage_type, nodes, {});
}
break;
}
case storage_type_t::SPARSE: {
@@ -897,6 +986,10 @@ void from_treelite(const raft::handle_t& handle, forest_t* pforest,
tl2fil_sparse(&trees, &nodes, &params, model, tl_params);
init_sparse(handle, pforest, trees.data(), nodes.data(), &params);
CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));
if (tl_params->pforest_shape_str) {
*tl_params->pforest_shape_str =
sprintf_shape(model, storage_type, nodes, trees);
}
break;
}
case storage_type_t::SPARSE8: {
@@ -905,6 +998,10 @@ void from_treelite(const raft::handle_t& handle, forest_t* pforest,
tl2fil_sparse(&trees, &nodes, &params, model, tl_params);
init_sparse(handle, pforest, trees.data(), nodes.data(), &params);
CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));
if (tl_params->pforest_shape_str) {
*tl_params->pforest_shape_str =
sprintf_shape(model, storage_type, nodes, trees);
}
break;
}
default:
@@ -915,12 +1012,32 @@
void from_treelite(const raft::handle_t& handle, forest_t* pforest,
ModelHandle model, const treelite_params_t* tl_params) {
const tl::Model& model_ref = *(tl::Model*)model;
model_ref.Dispatch([&handle, pforest, tl_params](const auto& model_inner) {
// model_inner is of the concrete type tl::ModelImpl<T, L>
model_ref.Dispatch([&](const auto& model_inner) {
// model_inner is of the concrete type tl::ModelImpl<threshold_t, leaf_t>
from_treelite(handle, pforest, model_inner, tl_params);
});
}

// allocates caller-owned char* using malloc()
template <typename threshold_t, typename leaf_t, typename node_t>
char* sprintf_shape(const tl::ModelImpl<threshold_t, leaf_t>& model,
storage_type_t storage, const std::vector<node_t>& nodes,
const std::vector<int>& trees) {
std::stringstream forest_shape = depth_hist_and_max(model);
float size_mb = (trees.size() * sizeof(trees.front()) +
nodes.size() * sizeof(nodes.front())) /
1e6;
forest_shape << storage_type_repr[storage] << " model size "
<< std::setprecision(2) << size_mb << " MB" << std::endl;
// stream may be discontiguous
std::string forest_shape_str = forest_shape.str();
// now copy to a non-owning allocation
char* shape_out = (char*)malloc(forest_shape_str.size() + 1); // incl. \0
memcpy((void*)shape_out, forest_shape_str.c_str(),
forest_shape_str.size() + 1);
return shape_out;
}

void free(const raft::handle_t& h, forest_t f) {
f->free(h);
delete f;
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
@@ -50,6 +50,7 @@ if(BUILD_CUML_TESTS)
sg/decisiontree_batchedlevel_algo.cu
sg/decisiontree_batchedlevel_unittest.cu
sg/fil_test.cu
sg/fnv_hash_test.cpp
sg/genetic/node_test.cpp
sg/genetic/param_test.cu
sg/handle_test.cu
15 changes: 15 additions & 0 deletions cpp/test/sg/fil_test.cu
@@ -57,6 +57,7 @@ struct FilTestParams {
algo_t algo = algo_t::NAIVE;
int seed = 42;
float tolerance = 2e-3f;
bool print_forest_shape = false;
// treelite parameters, only used for treelite tests
tl::Operator op = tl::Operator::kLT;
leaf_algo_t leaf_algo = leaf_algo_t::FLOAT_UNARY_BINARY;
@@ -612,8 +613,21 @@ class TreeliteFilTest : public BaseFilTest {
params.output_class = (ps.output & fil::output_t::CLASS) != 0;
params.storage_type = storage_type;
params.blocks_per_sm = ps.blocks_per_sm;
char* forest_shape_str = nullptr;
params.pforest_shape_str =
ps.print_forest_shape ? &forest_shape_str : nullptr;
fil::from_treelite(handle, pforest, (ModelHandle)model.get(), &params);
CUDA_CHECK(cudaStreamSynchronize(stream));
if (ps.print_forest_shape) {
std::string str(forest_shape_str);
for (const char* substr :
{"model size", " MB", "Depth histogram:", "Avg nodes per tree",
"Leaf depth", "Depth histogram fingerprint"}) {
ASSERT(str.find(substr) != std::string::npos,
"\"%s\" not found in forest shape :\n%s", substr, str.c_str());
}
}
::free(forest_shape_str);
}
};

@@ -861,6 +875,7 @@ std::vector<FilTestParams> import_dense_inputs = {
leaf_algo = GROVE_PER_CLASS, num_classes = 7),
FIL_TEST_PARAMS(num_trees = 48, output = CLASS, leaf_algo = GROVE_PER_CLASS,
num_classes = 6),
FIL_TEST_PARAMS(print_forest_shape = true),
};

TEST_P(TreeliteDenseFilTest, Import) { compare(); }