Skip to content

Commit

Permalink
[CUDA] Add feature interaction constraint for cuda_exp (fix microsoft…
Browse files Browse the repository at this point in the history
…#4785) (microsoft#5474)

* add feature interaction constraint for cuda_exp

* test feature interaction constraints for cuda_exp

* remove useless check

* update comment
  • Loading branch information
shiyu1994 authored Sep 7, 2022
1 parent a46c68f commit 1444a74
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 7 deletions.
4 changes: 4 additions & 0 deletions include/LightGBM/cuda/cuda_tree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ class CUDATree : public Tree {

void LaunchAddBiasKernel(const double val);

void RecordBranchFeatures(const int left_leaf_index,
const int right_leaf_index,
const int real_feature_index);

int* cuda_left_child_;
int* cuda_right_child_;
int* cuda_split_feature_inner_;
Expand Down
12 changes: 12 additions & 0 deletions src/io/cuda/cuda_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ int CUDATree::Split(const int leaf_index,
const MissingType missing_type,
const CUDASplitInfo* cuda_split_info) {
LaunchSplitKernel(leaf_index, real_feature_index, real_threshold, missing_type, cuda_split_info);
RecordBranchFeatures(leaf_index, num_leaves_, real_feature_index);
++num_leaves_;
return num_leaves_ - 1;
}
Expand All @@ -235,9 +236,20 @@ int CUDATree::SplitCategorical(const int leaf_index,
cuda_bitset_inner_.PushBack(cuda_bitset_inner, cuda_bitset_inner_len);
++num_leaves_;
++num_cat_;
RecordBranchFeatures(leaf_index, num_leaves_, real_feature_index);
return num_leaves_ - 1;
}

// Records, for the two leaves produced by a split, the list of features used
// on the path leading to them. Does nothing unless branch-feature tracking
// (track_branch_features_) is enabled.
//
// left_leaf_index:    leaf whose existing feature list is the starting point
// right_leaf_index:   leaf that receives a copy of that list
// real_feature_index: feature appended to both lists
void CUDATree::RecordBranchFeatures(const int left_leaf_index,
                                    const int right_leaf_index,
                                    const int real_feature_index) {
  if (!track_branch_features_) {
    return;
  }
  auto& left_path = branch_features_[left_leaf_index];
  auto& right_path = branch_features_[right_leaf_index];
  // The right leaf starts from the left leaf's path as it was before this split ...
  right_path = left_path;
  // ... and then both leaves get the new split feature appended.
  right_path.push_back(real_feature_index);
  left_path.push_back(real_feature_index);
}

void CUDATree::AddPredictionToScore(const Dataset* data,
data_size_t num_data,
double* score) const {
Expand Down
17 changes: 17 additions & 0 deletions src/treelearner/cuda/cuda_best_split_finder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ CUDABestSplitFinder::CUDABestSplitFinder(
const hist_t* cuda_hist,
const Dataset* train_data,
const std::vector<uint32_t>& feature_hist_offsets,
const bool select_features_by_node,
const Config* config):
num_features_(train_data->num_features()),
num_leaves_(config->num_leaves),
Expand All @@ -36,6 +37,7 @@ CUDABestSplitFinder::CUDABestSplitFinder(
use_smoothing_(config->path_smooth > 0),
path_smooth_(config->path_smooth),
num_total_bin_(feature_hist_offsets.empty() ? 0 : static_cast<int>(feature_hist_offsets.back())),
select_features_by_node_(select_features_by_node),
cuda_hist_(cuda_hist) {
InitFeatureMetaInfo(train_data);
cuda_leaf_best_split_info_ = nullptr;
Expand Down Expand Up @@ -105,6 +107,11 @@ void CUDABestSplitFinder::Init() {
AllocateCUDAMemory<data_size_t>(&cuda_feature_hist_index_buffer_, static_cast<size_t>(num_total_bin_), __FILE__, __LINE__);
}
}

if (select_features_by_node_) {
is_feature_used_by_smaller_node_.Resize(num_features_);
is_feature_used_by_larger_node_.Resize(num_features_);
}
}

void CUDABestSplitFinder::InitCUDAFeatureMetaInfo() {
Expand Down Expand Up @@ -364,6 +371,16 @@ void CUDABestSplitFinder::AllocateCatVectors(CUDASplitInfo* cuda_split_infos, ui
LaunchAllocateCatVectorsKernel(cuda_split_infos, cat_threshold_vec, cat_threshold_real_vec, len);
}

// Uploads the per-node feature masks from host vectors into the device-side
// vectors is_feature_used_by_smaller_node_ / is_feature_used_by_larger_node_.
// Nothing is copied unless select_features_by_node_ is set. Each copy
// transfers exactly as many elements as the corresponding host vector holds.
void CUDABestSplitFinder::SetUsedFeatureByNode(const std::vector<int8_t>& is_feature_used_by_smaller_node,
                                               const std::vector<int8_t>& is_feature_used_by_larger_node) {
  if (!select_features_by_node_) {
    return;
  }

  // Mask for the smaller leaf.
  const int8_t* const smaller_host_mask = is_feature_used_by_smaller_node.data();
  const size_t smaller_mask_len = is_feature_used_by_smaller_node.size();
  CopyFromHostToCUDADevice<int8_t>(
    is_feature_used_by_smaller_node_.RawData(), smaller_host_mask, smaller_mask_len, __FILE__, __LINE__);

  // Mask for the larger leaf.
  const int8_t* const larger_host_mask = is_feature_used_by_larger_node.data();
  const size_t larger_mask_len = is_feature_used_by_larger_node.size();
  CopyFromHostToCUDADevice<int8_t>(
    is_feature_used_by_larger_node_.RawData(), larger_host_mask, larger_mask_len, __FILE__, __LINE__);
}

} // namespace LightGBM

#endif // USE_CUDA_EXP
15 changes: 10 additions & 5 deletions src/treelearner/cuda/cuda_best_split_finder.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1375,7 +1375,6 @@ __global__ void FindBestSplitsForLeafKernel_GlobalMemory(
is_larger_leaf_valid

#define FindBestSplitsForLeafKernel_ARGS \
cuda_is_feature_used_bytree_, \
num_tasks_, \
cuda_split_find_tasks_.RawData(), \
cuda_randoms_.RawData(), \
Expand Down Expand Up @@ -1430,29 +1429,35 @@ void CUDABestSplitFinder::LaunchFindBestSplitsForLeafKernelInner1(LaunchFindBest

// Launches the best-split search kernels for the smaller and the larger leaf.
//
// The USE_RAND / USE_L1 / USE_SMOOTHING template flags are forwarded unchanged
// to the kernel; the kernel's fourth template flag is false for the smaller
// leaf launch and true for the larger leaf launch.
//
// Each launch uses num_tasks_ blocks of NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER
// threads with no dynamic shared memory; the smaller leaf goes to
// cuda_streams_[0] and the larger leaf to cuda_streams_[1].
//
// NOTE(review): every launch below is followed by two parenthesized argument
// lists (one without and one with the leading feature mask pointer). This looks
// like the old and new side of a diff shown together -- confirm which single
// argument list matches the FindBestSplitsForLeafKernel_ARGS macro in the file.
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING>
void CUDABestSplitFinder::LaunchFindBestSplitsForLeafKernelInner2(LaunchFindBestSplitsForLeafKernel_PARAMS) {
// Default: both leaves use the by-tree feature mask.
const int8_t* is_feature_used_by_smaller_node = cuda_is_feature_used_bytree_;
const int8_t* is_feature_used_by_larger_node = cuda_is_feature_used_bytree_;
// With per-node feature selection, switch to the per-node device masks instead.
if (select_features_by_node_) {
is_feature_used_by_smaller_node = is_feature_used_by_smaller_node_.RawData();
is_feature_used_by_larger_node = is_feature_used_by_larger_node_.RawData();
}
if (!use_global_memory_) {
// Variant used when use_global_memory_ is false: no extra buffer arguments.
if (is_smaller_leaf_valid) {
FindBestSplitsForLeafKernel<USE_RAND, USE_L1, USE_SMOOTHING, false>
<<<num_tasks_, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 0, cuda_streams_[0]>>>
(FindBestSplitsForLeafKernel_ARGS);
(is_feature_used_by_smaller_node, FindBestSplitsForLeafKernel_ARGS);
}
// Called unconditionally between the two launches; presumably a device-wide
// synchronize -- TODO confirm against the helper's definition.
SynchronizeCUDADevice(__FILE__, __LINE__);
if (is_larger_leaf_valid) {
FindBestSplitsForLeafKernel<USE_RAND, USE_L1, USE_SMOOTHING, true>
<<<num_tasks_, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 0, cuda_streams_[1]>>>
(FindBestSplitsForLeafKernel_ARGS);
(is_feature_used_by_larger_node, FindBestSplitsForLeafKernel_ARGS);
}
} else {
// Variant used when use_global_memory_ is true: same launch shape, but the
// _GlobalMemory kernel additionally receives GlobalMemory_Buffer_ARGS.
if (is_smaller_leaf_valid) {
FindBestSplitsForLeafKernel_GlobalMemory<USE_RAND, USE_L1, USE_SMOOTHING, false>
<<<num_tasks_, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 0, cuda_streams_[0]>>>
(FindBestSplitsForLeafKernel_ARGS, GlobalMemory_Buffer_ARGS);
(is_feature_used_by_smaller_node, FindBestSplitsForLeafKernel_ARGS, GlobalMemory_Buffer_ARGS);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
if (is_larger_leaf_valid) {
FindBestSplitsForLeafKernel_GlobalMemory<USE_RAND, USE_L1, USE_SMOOTHING, true>
<<<num_tasks_, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 0, cuda_streams_[1]>>>
(FindBestSplitsForLeafKernel_ARGS, GlobalMemory_Buffer_ARGS);
(is_feature_used_by_larger_node, FindBestSplitsForLeafKernel_ARGS, GlobalMemory_Buffer_ARGS);
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions src/treelearner/cuda/cuda_best_split_finder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class CUDABestSplitFinder {
const hist_t* cuda_hist,
const Dataset* train_data,
const std::vector<uint32_t>& feature_hist_offsets,
const bool select_features_by_node,
const Config* config);

~CUDABestSplitFinder();
Expand Down Expand Up @@ -88,6 +89,9 @@ class CUDABestSplitFinder {

void ResetConfig(const Config* config, const hist_t* cuda_hist);

void SetUsedFeatureByNode(const std::vector<int8_t>& is_feature_used_by_smaller_node,
const std::vector<int8_t>& is_feature_used_by_larger_node);

private:
#define LaunchFindBestSplitsForLeafKernel_PARAMS \
const CUDALeafSplitsStruct* smaller_leaf_splits, \
Expand Down Expand Up @@ -172,6 +176,8 @@ class CUDABestSplitFinder {
int max_num_categorical_bin_;
// marks whether a feature is categorical
std::vector<int8_t> is_categorical_;
// whether need to select features by node
bool select_features_by_node_;

// CUDA memory, held by this object
// for per leaf best split information
Expand All @@ -195,6 +201,9 @@ class CUDABestSplitFinder {
int max_num_categories_in_split_;
// used for extremely randomized trees
CUDAVector<CUDARandom> cuda_randoms_;
// features used by node
CUDAVector<int8_t> is_feature_used_by_smaller_node_;
CUDAVector<int8_t> is_feature_used_by_larger_node_;

// CUDA memory, held by other object
const hist_t* cuda_hist_;
Expand Down
18 changes: 17 additions & 1 deletion src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ void CUDASingleGPUTreeLearner::Init(const Dataset* train_data, bool is_constant_
cuda_histogram_constructor_->cuda_hist_pointer()));
cuda_data_partition_->Init();

select_features_by_node_ = !config_->interaction_constraints_vector.empty() || config_->feature_fraction_bynode < 1.0;
cuda_best_split_finder_.reset(new CUDABestSplitFinder(cuda_histogram_constructor_->cuda_hist(),
train_data_, this->share_state_->feature_hist_offsets(), config_));
train_data_, this->share_state_->feature_hist_offsets(), select_features_by_node_, config_));
cuda_best_split_finder_->Init();

leaf_best_split_feature_.resize(config_->num_leaves, -1);
Expand Down Expand Up @@ -149,6 +150,9 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients,
sum_hessians_in_larger_leaf);
global_timer.Stop("CUDASingleGPUTreeLearner::ConstructHistogramForLeaf");
global_timer.Start("CUDASingleGPUTreeLearner::FindBestSplitsForLeaf");

SelectFeatureByNode(tree.get());

cuda_best_split_finder_->FindBestSplitsForLeaf(
cuda_smaller_leaf_splits_->GetCUDAStruct(),
cuda_larger_leaf_splits_->GetCUDAStruct(),
Expand Down Expand Up @@ -464,6 +468,18 @@ void CUDASingleGPUTreeLearner::ResetBoostingOnGPU(const bool boosting_on_cuda) {
}
}

// Asks the column sampler which features may be used at the current smaller
// and larger leaves of `tree` and hands both masks to the CUDA best split
// finder. Does nothing unless select_features_by_node_ is set (interaction
// constraints are configured or features are sampled by node).
void CUDASingleGPUTreeLearner::SelectFeatureByNode(const Tree* tree) {
  if (!select_features_by_node_) {
    return;
  }

  // The smaller leaf is queried first, then the larger one; keep this order.
  const std::vector<int8_t>& smaller_node_mask = col_sampler_.GetByNode(tree, smaller_leaf_index_);

  // Stays empty when there is no larger leaf (negative index).
  std::vector<int8_t> larger_node_mask;
  if (larger_leaf_index_ >= 0) {
    larger_node_mask = col_sampler_.GetByNode(tree, larger_leaf_index_);
  }

  cuda_best_split_finder_->SetUsedFeatureByNode(smaller_node_mask, larger_node_mask);
}

#ifdef DEBUG
void CUDASingleGPUTreeLearner::CheckSplitValid(
const int left_leaf,
Expand Down
4 changes: 4 additions & 0 deletions src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {

void AllocateBitset();

void SelectFeatureByNode(const Tree* tree);

#ifdef DEUBG
void CheckSplitValid(
const int left_leaf, const int right_leaf,
Expand Down Expand Up @@ -100,6 +102,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
int best_leaf_index_;
int num_cat_threshold_;
bool has_categorical_feature_;
// whether need to select features by node
bool select_features_by_node_;

std::vector<int> categorical_bin_to_value_;
std::vector<int> categorical_bin_offsets_;
Expand Down
1 change: 0 additions & 1 deletion tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3120,7 +3120,6 @@ def _imptcs_to_numpy(X, impcts_dict):
assert tree_df.loc[0, col] is None


@pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version')
def test_interaction_constraints():
X, y = load_boston(return_X_y=True)
num_features = X.shape[1]
Expand Down

0 comments on commit 1444a74

Please sign in to comment.