Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TreeExplainer extensions #4697

Merged
merged 28 commits into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
8e891ff
Fix cast to incorrect model type.
RAMitchell Mar 24, 2022
e7ee0dc
Hypothesis tests
RAMitchell Mar 30, 2022
6d72d1c
Lint
RAMitchell Mar 30, 2022
dab35e0
Merge branch 'branch-22.04' of https://github.com/rapidsai/cuml into …
RAMitchell Mar 31, 2022
6d023e1
Refactor
RAMitchell Mar 31, 2022
aa762dc
64 bit support
RAMitchell Apr 4, 2022
a4d1fb2
Abstract class version.
RAMitchell Apr 5, 2022
8efb53c
Void ptr
RAMitchell Apr 6, 2022
9454395
Variant shared_ptr
RAMitchell Apr 6, 2022
334020c
Type erase data
RAMitchell Apr 6, 2022
a288a67
Merge branch 'branch-22.06' of https://github.com/rapidsai/cuml into …
RAMitchell Apr 6, 2022
030b67a
Disable HealthCheck, reduce examples
RAMitchell Apr 7, 2022
4f67989
Merge branch 'branch-22.06' of https://github.com/rapidsai/cuml into …
RAMitchell Apr 7, 2022
5064739
Interactions
RAMitchell Apr 11, 2022
ab90cee
Merge branch 'branch-22.06' of https://github.com/rapidsai/cuml into …
RAMitchell Apr 11, 2022
82f42ef
Docs
RAMitchell Apr 12, 2022
b3408f4
Flake8
RAMitchell Apr 12, 2022
e037db1
Merge branch 'branch-22.06' of https://github.com/rapidsai/cuml into …
RAMitchell Apr 12, 2022
a0086f3
Add virtual destructor
RAMitchell Apr 13, 2022
18aa935
Directly support xgb/lgbm sklearn model objects
RAMitchell Apr 13, 2022
38969ee
Promote from experimental
RAMitchell Apr 13, 2022
1382077
Merge branch 'treeshap-test' of github.com:RAMitchell/cuml into trees…
RAMitchell Apr 13, 2022
565bd4e
Merge
RAMitchell Apr 18, 2022
e67222e
Update docs
RAMitchell Apr 19, 2022
3a98a8b
Review comments
RAMitchell Apr 21, 2022
a110f57
Merge branch 'branch-22.06' of https://github.com/rapidsai/cuml into …
RAMitchell Apr 25, 2022
40285ea
Tests
RAMitchell Apr 25, 2022
c47b354
Disable xgboost tests in cuml
RAMitchell May 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions cpp/include/cuml/explainer/tree_shap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,51 @@
#include <cstdint>
#include <cuml/ensemble/treelite_defs.hpp>
#include <memory>
#include <variant>

namespace ML {
namespace Explainer {

// An abstract class representing an opaque handle to path information
// extracted from a tree model. The implementation in tree_shap.cu will
// define an internal class that inherits from this abtract class.
class TreePathInfo {
public:
enum class ThresholdTypeEnum : std::uint8_t { kFloat, kDouble };
virtual ThresholdTypeEnum GetThresholdType() const = 0;
virtual ~TreePathInfo() {}
};

std::unique_ptr<TreePathInfo> extract_path_info(ModelHandle model);
void gpu_treeshap(TreePathInfo* path_info,
const float* data,
template <typename T>
class TreePathInfo;

using TreePathHandle =
std::variant<std::shared_ptr<TreePathInfo<float>>, std::shared_ptr<TreePathInfo<double>>>;

using FloatPointer = std::variant<float*, double*>;

TreePathHandle extract_path_info(ModelHandle model);

void gpu_treeshap(TreePathHandle path_info,
const FloatPointer data,
std::size_t n_rows,
std::size_t n_cols,
float* out_preds);
FloatPointer out_preds,
std::size_t out_preds_size);

void gpu_treeshap_interventional(TreePathHandle path_info,
const FloatPointer data,
std::size_t n_rows,
std::size_t n_cols,
const FloatPointer background_data,
std::size_t background_n_rows,
std::size_t background_n_cols,
FloatPointer out_preds,
std::size_t out_preds_size);

void gpu_treeshap_interactions(TreePathHandle path_info,
const FloatPointer data,
std::size_t n_rows,
std::size_t n_cols,
FloatPointer out_preds,
std::size_t out_preds_size);

void gpu_treeshap_taylor_interactions(TreePathHandle path_info,
const FloatPointer data,
std::size_t n_rows,
std::size_t n_cols,
FloatPointer out_preds,
std::size_t out_preds_size);

} // namespace Explainer
} // namespace ML
} // namespace ML
Loading