API updates #23

Merged 4 commits on Dec 6, 2020
Changes from 1 commit

Update API to accept thrust style iterator
RAMitchell committed Dec 2, 2020
commit fa033ca52e60da1a711c988177f9150cae5b1b4f
GPUTreeShap/gpu_treeshap.h: 195 changes (95 additions, 100 deletions)
@@ -977,44 +977,40 @@ void ComputeBias(const PathVectorT& device_paths, DoubleVectorT* bias) {
* Compute feature contributions on the GPU given a set of unique paths through a tree ensemble
* and a dataset. Uses device memory proportional to the tree ensemble size.
*
+ * \exception std::invalid_argument Thrown when an invalid argument error condition occurs.
+ * \tparam PathIteratorT    Thrust type iterator, may be thrust::device_ptr for device memory, or
+ *                          stl iterator/raw pointer for host memory.
+ * \tparam PhiIteratorT     Thrust type iterator, may be thrust::device_ptr for device memory, or
+ *                          stl iterator/raw pointer for host memory. Value type must be floating
+ *                          point.
+ * \tparam DatasetT         User-specified dataset container.
+ * \tparam DeviceAllocatorT Optional thrust style allocator.
 *
- * \param X                 Thin wrapper over a dataset allocated in device memory. X
- *                          should be trivially copyable as a kernel parameter (i.e.
- *                          contain only pointers to actual data) and must implement the
- *                          methods NumRows()/NumCols()/GetElement(size_t row_idx, size_t
- *                          col_idx) as __device__ functions. GetElement may return NaN
- *                          where the feature value is missing.
- * \param begin             Iterator to paths, where separate paths are delineated by
- *                          PathElement.path_idx. Each unique path should contain 1 root
- *                          with feature_idx = -1 and zero_fraction = 1.0. The ordering of
- *                          path elements inside a unique path does not matter - the result
- *                          will be the same. Paths may contain duplicate features. See the
- *                          PathElement class for more information.
- * \param end               Path end iterator.
- * \param num_groups        Number of output groups. In multiclass classification the
- *                          algorithm outputs feature contributions per output class.
- * \param [in,out] phis_out Device memory buffer for returning the feature contributions.
- *                          The last feature column contains the bias term. Feature
- *                          contributions can be retrieved by phis_out[(row_idx *
- *                          num_groups + group) * (X.NumCols() + 1) + feature_idx]. Results
- *                          are added to the input buffer without zeroing memory - do not
- *                          pass uninitialised memory.
- * \param phis_out_length   Length of the phis_out for bounds checking. Must be at least of
- *                          size X.NumRows() * (X.NumCols() + 1) * num_groups.
- *
- * \tparam DatasetT User-specified dataset container.
- *
- * \tparam PathIteratorT Thrust type iterator, may be thrust::device_ptr for device memory, or stl
- *                       iterator/raw pointer for host memory.
+ * \param X          Thin wrapper over a dataset allocated in device memory. X should be trivially
+ *                   copyable as a kernel parameter (i.e. contain only pointers to actual data) and
+ *                   must implement the methods NumRows()/NumCols()/GetElement(size_t row_idx,
+ *                   size_t col_idx) as __device__ functions. GetElement may return NaN where the
+ *                   feature value is missing.
+ * \param begin      Iterator to paths, where separate paths are delineated by
+ *                   PathElement.path_idx. Each unique path should contain 1 root with feature_idx =
+ *                   -1 and zero_fraction = 1.0. The ordering of path elements inside a unique path
+ *                   does not matter - the result will be the same. Paths may contain duplicate
+ *                   features. See the PathElement class for more information.
+ * \param end        Path end iterator.
+ * \param num_groups Number of output groups. In multiclass classification the algorithm outputs
+ *                   feature contributions per output class.
+ * \param phis_begin Begin iterator for output phis.
+ * \param phis_end   End iterator for output phis.
*/
template <typename DeviceAllocatorT = thrust::device_allocator<int>,
-                typename DatasetT, typename PathIteratorT>
+                typename DatasetT, typename PathIteratorT, typename PhiIteratorT>
void GPUTreeShap(DatasetT X, PathIteratorT begin, PathIteratorT end,
-                size_t num_groups, float* phis_out, size_t phis_out_length) {
+                size_t num_groups, PhiIteratorT phis_begin,
+                PhiIteratorT phis_end) {
if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return;

-  if (phis_out_length < X.NumRows() * (X.NumCols() + 1) * num_groups) {
+  if (phis_end - phis_begin < X.NumRows() * (X.NumCols() + 1) * num_groups) {
throw std::invalid_argument(
"phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * "
"num_groups");
@@ -1025,7 +1021,7 @@ void GPUTreeShap(DatasetT X, PathIteratorT begin, PathIteratorT end,
using path_vector = detail::RebindVector<PathElement, DeviceAllocatorT>;

// Compute the global bias
-  double_vector temp_phi(phis_out_length, 0.0);
+  double_vector temp_phi(phis_end - phis_begin, 0.0);
path_vector device_paths(begin, end);
double_vector bias(num_groups, 0.0);
detail::ComputeBias<path_vector, double_vector, DeviceAllocatorT>(
@@ -1049,48 +1045,46 @@
detail::ComputeShap(X, device_bin_segments, deduplicated_paths, num_groups,
temp_phi.data().get());
thrust::copy(temp_phi.begin(), temp_phi.end(),
-               thrust::device_pointer_cast(phis_out));
+               phis_begin);
}
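
A minimal usage sketch of the updated signature, assuming a hypothetical `DenseDatasetWrapper` that meets the `DatasetT` requirements documented above (a matching sketch of such a wrapper follows the diff):

```cpp
#include <thrust/device_vector.h>

#include <vector>

#include "gpu_treeshap.h"

// `X` is a hypothetical dataset wrapper; `paths` holds the unique decision
// paths extracted from the tree ensemble.
void RunShap(const DenseDatasetWrapper& X,
             const std::vector<gpu_treeshap::PathElement>& paths,
             size_t num_groups) {
  // One output column per feature plus one for the bias term, per row/group.
  thrust::device_vector<float> phis(
      X.NumRows() * (X.NumCols() + 1) * num_groups, 0.0f);
  // device_vector iterators satisfy PhiIteratorT, replacing the old
  // float*/length pair.
  gpu_treeshap::GPUTreeShap(X, paths.begin(), paths.end(), num_groups,
                            phis.begin(), phis.end());
  // The contribution of feature_idx for a given row and group lands at
  // phis[(row_idx * num_groups + group) * (X.NumCols() + 1) + feature_idx].
}
```
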

/*!
* Compute feature interaction contributions on the GPU given a set of unique paths through a tree
* ensemble and a dataset. Uses device memory proportional to the tree ensemble size.
*
- * \tparam DatasetT User-specified dataset container.
- *
- * \param X                 Thin wrapper over a dataset allocated in device memory. X
- *                          should be trivially copyable as a kernel parameter (i.e.
- *                          contain only pointers to actual data) and must implement the
- *                          methods NumRows()/NumCols()/GetElement(size_t row_idx, size_t
- *                          col_idx) as __device__ functions. GetElement may return NaN
- *                          where the feature value is missing.
- * \param begin             Iterator to paths, where separate paths are delineated by
- *                          PathElement.path_idx. Each unique path should contain 1 root
- *                          with feature_idx = -1 and zero_fraction = 1.0. The ordering of
- *                          path elements inside a unique path does not matter - the result
- *                          will be the same. Paths may contain duplicate features. See the
- *                          PathElement class for more information.
- * \param end               Path end iterator.
- * \param num_groups        Number of output groups. In multiclass classification the
- *                          algorithm outputs feature contributions per output class.
- * \param [in,out] phis_out Device memory buffer for returning the feature interaction
- *                          contributions. The last feature column contains the bias term.
- *                          Results are added to the input buffer without zeroing memory -
- *                          do not pass uninitialised memory.
- * \param phis_out_length   Length of the phis_out for bounds checking. Must be at least
- *                          size X.NumRows() * (X.NumCols() + 1) * (X.NumCols() + 1) *
- *                          num_groups.
+ * \exception std::invalid_argument Thrown when an invalid argument error condition occurs.
+ * \tparam PhiIteratorT     Thrust type iterator, may be thrust::device_ptr for device memory, or
+ *                          stl iterator/raw pointer for host memory. Value type must be floating
+ *                          point.
+ * \tparam PathIteratorT    Thrust type iterator, may be thrust::device_ptr for device memory, or
+ *                          stl iterator/raw pointer for host memory.
+ * \tparam DatasetT         User-specified dataset container.
+ * \tparam DeviceAllocatorT Optional thrust style allocator.
+ *
- * \tparam PathIteratorT Thrust type iterator, may be thrust::device_ptr for device memory, or stl
- *                       iterator/raw pointer for host memory.
+ * \param X          Thin wrapper over a dataset allocated in device memory. X should be trivially
+ *                   copyable as a kernel parameter (i.e. contain only pointers to actual data) and
+ *                   must implement the methods NumRows()/NumCols()/GetElement(size_t row_idx,
+ *                   size_t col_idx) as __device__ functions. GetElement may return NaN where the
+ *                   feature value is missing.
+ * \param begin      Iterator to paths, where separate paths are delineated by
+ *                   PathElement.path_idx. Each unique path should contain 1 root with feature_idx =
+ *                   -1 and zero_fraction = 1.0. The ordering of path elements inside a unique path
+ *                   does not matter - the result will be the same. Paths may contain duplicate
+ *                   features. See the PathElement class for more information.
+ * \param end        Path end iterator.
+ * \param num_groups Number of output groups. In multiclass classification the algorithm outputs
+ *                   feature contributions per output class.
+ * \param phis_begin Begin iterator for output phis.
+ * \param phis_end   End iterator for output phis.
*/
template <typename DeviceAllocatorT = thrust::device_allocator<int>,
-                typename DatasetT, typename PathIteratorT>
+                typename DatasetT, typename PathIteratorT, typename PhiIteratorT>
void GPUTreeShapInteractions(DatasetT X, PathIteratorT begin, PathIteratorT end,
-                             size_t num_groups, float* phis_out,
-                             size_t phis_out_length) {
+                             size_t num_groups, PhiIteratorT phis_begin,
+                             PhiIteratorT phis_end) {
if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return;
-  if (phis_out_length <
+  if (phis_end - phis_begin <
X.NumRows() * (X.NumCols() + 1) * (X.NumCols() + 1) * num_groups) {
throw std::invalid_argument(
"phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * "
@@ -1103,7 +1097,7 @@ void GPUTreeShapInteractions(DatasetT X, PathIteratorT begin, PathIteratorT end,
using path_vector = detail::RebindVector<PathElement, DeviceAllocatorT>;

// Compute the global bias
-  double_vector temp_phi(phis_out_length, 0.0);
+  double_vector temp_phi(phis_end - phis_begin, 0.0);
path_vector device_paths(begin, end);
double_vector bias(num_groups, 0.0);
detail::ComputeBias<path_vector, double_vector, DeviceAllocatorT>(
@@ -1127,49 +1121,51 @@

detail::ComputeShapInteractions(X, device_bin_segments, deduplicated_paths,
num_groups, temp_phi.data().get());
-  thrust::copy(temp_phi.begin(), temp_phi.end(),
-               thrust::device_pointer_cast(phis_out));
+  thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin);
}
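
Since `PhiIteratorT` may also be an stl iterator over host memory, interaction values can be collected straight into a host buffer. A sketch under the same assumptions as above:

```cpp
#include <vector>

#include "gpu_treeshap.h"

void RunInteractions(const DenseDatasetWrapper& X,  // hypothetical wrapper
                     const std::vector<gpu_treeshap::PathElement>& paths,
                     size_t num_groups) {
  // Buffer size mirrors the bounds check above:
  // NumRows() * (NumCols() + 1) * (NumCols() + 1) * num_groups.
  std::vector<double> phis(
      X.NumRows() * (X.NumCols() + 1) * (X.NumCols() + 1) * num_groups, 0.0);
  // std::vector iterators are valid PhiIteratorT arguments; the final
  // thrust::copy performs the device-to-host transfer.
  gpu_treeshap::GPUTreeShapInteractions(X, paths.begin(), paths.end(),
                                        num_groups, phis.begin(), phis.end());
}
```
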

/*!
- * Compute feature interaction contributions using the Shapley Taylor index on the GPU, given a set of unique paths through a tree
- * ensemble and a dataset. Uses device memory proportional to the tree ensemble size.
- *
- * \tparam DatasetT User-specified dataset container.
- *
- * \param X                 Thin wrapper over a dataset allocated in device memory. X
- *                          should be trivially copyable as a kernel parameter (i.e.
- *                          contain only pointers to actual data) and must implement the
- *                          methods NumRows()/NumCols()/GetElement(size_t row_idx, size_t
- *                          col_idx) as __device__ functions. GetElement may return NaN
- *                          where the feature value is missing.
- * \param begin             Iterator to paths, where separate paths are delineated by
- *                          PathElement.path_idx. Each unique path should contain 1 root
- *                          with feature_idx = -1 and zero_fraction = 1.0. The ordering of
- *                          path elements inside a unique path does not matter - the result
- *                          will be the same. Paths may contain duplicate features. See the
- *                          PathElement class for more information.
- * \param end               Path end iterator.
- * \param num_groups        Number of output groups. In multiclass classification the
- *                          algorithm outputs feature contributions per output class.
- * \param [in,out] phis_out Device memory buffer for returning the feature interaction
- *                          contributions. The last feature column contains the bias term.
- *                          Results are added to the input buffer without zeroing memory -
- *                          do not pass uninitialised memory.
- * \param phis_out_length   Length of the phis_out for bounds checking. Must be at least
- *                          size X.NumRows() * (X.NumCols() + 1) * (X.NumCols() + 1) *
- *                          num_groups.
+ * Compute feature interaction contributions using the Shapley Taylor index on the GPU, given a
+ * set of unique paths through a tree ensemble and a dataset. Uses device memory proportional to
+ * the tree ensemble size.
+ *
+ * \exception std::invalid_argument Thrown when an invalid argument error condition occurs.
+ * \tparam DeviceAllocatorT Optional thrust style allocator.
+ * \tparam DatasetT         User-specified dataset container.
+ * \tparam PathIteratorT    Thrust type iterator, may be thrust::device_ptr for device memory, or
+ *                          stl iterator/raw pointer for host memory.
+ * \tparam PhiIteratorT     Thrust type iterator, may be thrust::device_ptr for device memory, or
+ *                          stl iterator/raw pointer for host memory. Value type must be floating
+ *                          point.
+ *
+ * \param X          Thin wrapper over a dataset allocated in device memory. X should be trivially
+ *                   copyable as a kernel parameter (i.e. contain only pointers to actual data) and
+ *                   must implement the methods NumRows()/NumCols()/GetElement(size_t row_idx,
+ *                   size_t col_idx) as __device__ functions. GetElement may return NaN where the
+ *                   feature value is missing.
+ * \param begin      Iterator to paths, where separate paths are delineated by
+ *                   PathElement.path_idx. Each unique path should contain 1 root with feature_idx =
+ *                   -1 and zero_fraction = 1.0. The ordering of path elements inside a unique path
+ *                   does not matter - the result will be the same. Paths may contain duplicate
+ *                   features. See the PathElement class for more information.
+ * \param end        Path end iterator.
+ * \param num_groups Number of output groups. In multiclass classification the algorithm outputs
+ *                   feature contributions per output class.
+ * \param phis_begin Begin iterator for output phis.
+ * \param phis_end   End iterator for output phis.
*/
template <typename DeviceAllocatorT = thrust::device_allocator<int>,
-                typename DatasetT, typename PathIteratorT>
+                typename DatasetT, typename PathIteratorT, typename PhiIteratorT>
void GPUTreeShapTaylorInteractions(DatasetT X, PathIteratorT begin,
                                   PathIteratorT end, size_t num_groups,
-                                   float* phis_out, size_t phis_out_length) {
+                                   PhiIteratorT phis_begin,
+                                   PhiIteratorT phis_end) {
+  using phis_type = typename std::iterator_traits<PhiIteratorT>::value_type;
+  static_assert(std::is_floating_point<phis_type>::value,
+                "Phis type must be floating point");
+
  if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return;
-  if (phis_out_length <
+
+  if (phis_end - phis_begin <
X.NumRows() * (X.NumCols() + 1) * (X.NumCols() + 1) * num_groups) {
throw std::invalid_argument(
"phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * "
@@ -1182,7 +1178,7 @@ void GPUTreeShapTaylorInteractions(DatasetT X, PathIteratorT begin,
using path_vector = detail::RebindVector<PathElement, DeviceAllocatorT>;

// Compute the global bias
-  double_vector temp_phi(phis_out_length, 0.0);
+  double_vector temp_phi(phis_end - phis_begin, 0.0);
path_vector device_paths(begin, end);
double_vector bias(num_groups, 0.0);
detail::ComputeBias<path_vector, double_vector, DeviceAllocatorT>(
@@ -1207,7 +1203,6 @@ void GPUTreeShapTaylorInteractions(DatasetT X, PathIteratorT begin,
detail::ComputeShapTaylorInteractions(X, device_bin_segments,
deduplicated_paths, num_groups,
temp_phi.data().get());
-  thrust::copy(temp_phi.begin(), temp_phi.end(),
-               thrust::device_pointer_cast(phis_out));
+  thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin);
}
-};  // namespace gpu_treeshap
+}  // namespace gpu_treeshap
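
For reference, a minimal sketch of a dataset wrapper meeting the documented `DatasetT` requirements; the `DenseDatasetWrapper` name used in the sketches above is illustrative and not part of this diff:

```cpp
#include <cstddef>

// Trivially copyable as a kernel parameter: holds only a pointer to the
// actual data plus the matrix dimensions.
class DenseDatasetWrapper {
  const float* data_;  // dense row-major matrix in device memory
  std::size_t num_rows_;
  std::size_t num_cols_;

 public:
  DenseDatasetWrapper(const float* data, std::size_t num_rows,
                      std::size_t num_cols)
      : data_(data), num_rows_(num_rows), num_cols_(num_cols) {}
  // May return NaN where a feature value is missing.
  __device__ float GetElement(std::size_t row_idx, std::size_t col_idx) const {
    return data_[row_idx * num_cols_ + col_idx];
  }
  // Host-callable as well, since the bounds checks in the functions above
  // call NumRows()/NumCols() on the host.
  __host__ __device__ std::size_t NumRows() const { return num_rows_; }
  __host__ __device__ std::size_t NumCols() const { return num_cols_; }
};
```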