Commit be7f321
[ci][fix] Fix cuda_exp ci (#5438)
* fix cuda_exp ci
* fix ci failures introduced by #5279
* cleanup cuda.yml
* fix test.sh
* clean up test.sh
* clean up test.sh
* skip lines by cuda_exp in test_register_logger
* Update tests/python_package_test/test_utilities.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
1 parent ef006b7 commit be7f321

13 files changed: +159 −64 lines

.github/workflows/cuda.yml

Lines changed: 7 additions & 8 deletions
@@ -11,12 +11,11 @@ on:
 env:
   github_actions: 'true'
   os_name: linux
-  task: cuda
   conda_env: test-env
 
 jobs:
   test:
-    name: ${{ matrix.tree_learner }} ${{ matrix.cuda_version }} ${{ matrix.method }} (linux, ${{ matrix.compiler }}, Python ${{ matrix.python_version }})
+    name: ${{ matrix.task }} ${{ matrix.cuda_version }} ${{ matrix.method }} (linux, ${{ matrix.compiler }}, Python ${{ matrix.python_version }})
     runs-on: [self-hosted, linux]
     timeout-minutes: 60
     strategy:
@@ -27,27 +26,27 @@ jobs:
           compiler: gcc
           python_version: "3.8"
           cuda_version: "11.7.1"
-          tree_learner: cuda
+          task: cuda
         - method: pip
           compiler: clang
           python_version: "3.9"
           cuda_version: "10.0"
-          tree_learner: cuda
+          task: cuda
         - method: wheel
           compiler: gcc
           python_version: "3.10"
           cuda_version: "9.0"
-          tree_learner: cuda
+          task: cuda
         - method: source
           compiler: gcc
           python_version: "3.8"
           cuda_version: "11.7.1"
-          tree_learner: cuda_exp
+          task: cuda_exp
         - method: pip
           compiler: clang
           python_version: "3.9"
           cuda_version: "10.0"
-          tree_learner: cuda_exp
+          task: cuda_exp
     steps:
       - name: Setup or update software on host machine
         run: |
@@ -86,7 +85,7 @@ jobs:
           GITHUB_ACTIONS=${{ env.github_actions }}
           OS_NAME=${{ env.os_name }}
           COMPILER=${{ matrix.compiler }}
-          TASK=${{ env.task }}
+          TASK=${{ matrix.task }}
           METHOD=${{ matrix.method }}
           CONDA_ENV=${{ env.conda_env }}
           PYTHON_VERSION=${{ matrix.python_version }}

include/LightGBM/cuda/cuda_utils.h

Lines changed: 4 additions & 0 deletions
@@ -171,6 +171,10 @@ class CUDAVector {
     return data_;
   }
 
+  const T* RawDataReadOnly() const {
+    return data_;
+  }
+
  private:
   T* data_;
   size_t size_;
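
Note: RawDataReadOnly() complements the existing non-const RawData() accessor so that const-qualified callers can read the device buffer without a const_cast. A minimal sketch of a const-correct caller (the SumOnHost helper is illustrative, not part of this commit; it assumes cuda_utils.h is included and that CUDAVector exposes a Size() accessor alongside size_):

#include <vector>

// Hypothetical const-correct reader for a CUDAVector<T>; not part of this commit.
template <typename T>
T SumOnHost(const CUDAVector<T>& vec) {
  std::vector<T> host(vec.Size());
  // vec is const, so the non-const RawData() is unavailable;
  // RawDataReadOnly() returns a const T* view of the same device buffer.
  CopyFromCUDADeviceToHost<T>(host.data(), vec.RawDataReadOnly(), vec.Size(), __FILE__, __LINE__);
  T sum = 0;
  for (const T& v : host) sum += v;
  return sum;
}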

include/LightGBM/tree_learner.h

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,12 @@ class TreeLearner {
   */
   virtual void ResetConfig(const Config* config) = 0;
 
+  /*!
+  * \brief Reset boosting_on_gpu_
+  * \param boosting_on_gpu flag for boosting on GPU
+  */
+  virtual void ResetBoostingOnGPU(const bool /*boosting_on_gpu*/) {}
+
   virtual void SetForcedSplit(const Json* forced_split_json) = 0;
 
   /*!
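
The base-class default is a no-op, so only learners that care about where the gradients live need to override it. A fragment of what an override might look like (the class and member names are hypothetical; all other TreeLearner overrides are elided):

// Hypothetical GPU-aware learner fragment; not part of this commit.
class ExampleCUDATreeLearner : public SerialTreeLearner {
 public:
  explicit ExampleCUDATreeLearner(const Config* config) : SerialTreeLearner(config) {}

  void ResetBoostingOnGPU(const bool boosting_on_gpu) override {
    // Record whether gradients/hessians will arrive as device pointers,
    // so training can skip (or add) host-to-device copies accordingly.
    boosting_on_gpu_ = boosting_on_gpu;
  }

 private:
  bool boosting_on_gpu_ = false;
};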

src/boosting/cuda/cuda_score_updater.cpp

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ CUDAScoreUpdater::CUDAScoreUpdater(const Dataset* data, int num_tree_per_iterati
     has_init_score_ = true;
     CopyFromHostToCUDADevice<double>(cuda_score_, init_score, total_size, __FILE__, __LINE__);
   } else {
-    SetCUDAMemory<double>(cuda_score_, 0, static_cast<size_t>(num_data_), __FILE__, __LINE__);
+    SetCUDAMemory<double>(cuda_score_, 0, static_cast<size_t>(total_size), __FILE__, __LINE__);
   }
   SynchronizeCUDADevice(__FILE__, __LINE__);
   if (boosting_on_cuda_) {
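
This one-word fix matters for multiclass training: cuda_score_ holds num_data_ * num_tree_per_iteration_ doubles, so zeroing only the first num_data_ entries left every per-class segment after the first uninitialized. A standalone sketch of the arithmetic (illustrative numbers, not LightGBM code):

#include <cstdio>

int main() {
  const int num_data = 1000;
  const int num_tree_per_iteration = 3;  // e.g. 3-class softmax
  const size_t total_size = static_cast<size_t>(num_data) * num_tree_per_iteration;
  // Before the fix only the first num_data entries were zeroed, leaving
  // 2000 doubles (the segments for classes 1 and 2) uninitialized here.
  std::printf("zeroed: %d of %zu\n", num_data, total_size);
  return 0;
}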

src/boosting/cuda/cuda_score_updater.cu

Lines changed: 3 additions & 5 deletions
@@ -11,24 +11,22 @@ namespace LightGBM {
 
 __global__ void AddScoreConstantKernel(
   const double val,
-  const size_t offset,
   const data_size_t num_data,
   double* score) {
   const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
   if (data_index < num_data) {
-    score[data_index + offset] += val;
+    score[data_index] += val;
   }
 }
 
 void CUDAScoreUpdater::LaunchAddScoreConstantKernel(const double val, const size_t offset) {
   const int num_blocks = (num_data_ + num_threads_per_block_) / num_threads_per_block_;
   Log::Debug("Adding init score = %lf", val);
-  AddScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, offset, num_data_, cuda_score_);
+  AddScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, num_data_, cuda_score_ + offset);
 }
 
 __global__ void MultiplyScoreConstantKernel(
   const double val,
-  const size_t offset,
   const data_size_t num_data,
   double* score) {
   const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
@@ -39,7 +37,7 @@ __global__ void MultiplyScoreConstantKernel(
 
 void CUDAScoreUpdater::LaunchMultiplyScoreConstantKernel(const double val, const size_t offset) {
   const int num_blocks = (num_data_ + num_threads_per_block_) / num_threads_per_block_;
-  MultiplyScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, offset, num_data_, cuda_score_);
+  MultiplyScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, num_data_, cuda_score_ + offset);
 }
 
 }  // namespace LightGBM
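
Dropping the offset parameter and passing cuda_score_ + offset at launch is the usual pointer-arithmetic idiom: the kernel always indexes from zero, which also avoids mixing size_t offsets with data_size_t indices inside device code. A self-contained CUDA sketch of the same pattern (all names here are illustrative):

#include <cuda_runtime.h>

// The kernel indexes from zero; the host applies the segment offset
// through pointer arithmetic at launch time.
__global__ void AddConstant(const double val, const int num_data, double* score) {
  const int i = static_cast<int>(threadIdx.x + blockIdx.x * blockDim.x);
  if (i < num_data) {
    score[i] += val;
  }
}

int main() {
  const int num_data = 1024;
  const size_t offset = 256;  // e.g. start of the second class's segment
  double* d_score = nullptr;
  cudaMalloc(&d_score, (offset + num_data) * sizeof(double));
  cudaMemset(d_score, 0, (offset + num_data) * sizeof(double));
  const int threads = 256;
  const int blocks = (num_data + threads - 1) / threads;
  AddConstant<<<blocks, threads>>>(0.5, num_data, d_score + offset);
  cudaDeviceSynchronize();
  cudaFree(d_score);
  return 0;
}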

src/boosting/gbdt.cpp

Lines changed: 79 additions & 36 deletions
@@ -41,6 +41,9 @@ GBDT::GBDT()
   average_output_ = false;
   tree_learner_ = nullptr;
   linear_tree_ = false;
+  gradients_pointer_ = nullptr;
+  hessians_pointer_ = nullptr;
+  boosting_on_gpu_ = false;
 }
 
 GBDT::~GBDT() {
@@ -95,9 +98,9 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
 
   is_constant_hessian_ = GetIsConstHessian(objective_function);
 
-  const bool boosting_on_gpu = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
+  boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
   tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type,
-                                                                              config_.get(), boosting_on_gpu));
+                                                                              config_.get(), boosting_on_gpu_));
 
   // init tree learner
   tree_learner_->Init(train_data_, is_constant_hessian_);
@@ -112,7 +115,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
 
 #ifdef USE_CUDA_EXP
   if (config_->device_type == std::string("cuda_exp")) {
-    train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu));
+    train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu_));
   } else {
 #endif  // USE_CUDA_EXP
   train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
@@ -123,9 +126,14 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
   num_data_ = train_data_->num_data();
   // create buffer for gradients and Hessians
   if (objective_function_ != nullptr) {
-    size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
+    const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
 #ifdef USE_CUDA_EXP
-    if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu) {
+    if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu_) {
+      if (gradients_pointer_ != nullptr) {
+        CHECK_NOTNULL(hessians_pointer_);
+        DeallocateCUDAMemory<score_t>(&gradients_pointer_, __FILE__, __LINE__);
+        DeallocateCUDAMemory<score_t>(&hessians_pointer_, __FILE__, __LINE__);
+      }
       AllocateCUDAMemory<score_t>(&gradients_pointer_, total_size, __FILE__, __LINE__);
       AllocateCUDAMemory<score_t>(&hessians_pointer_, total_size, __FILE__, __LINE__);
     } else {
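
The new null check makes repeated initialization safe: Init(), ResetTrainingData(), and ResetBaggingConfig() can now each run more than once without leaking the previously allocated device buffers. The same guard in isolation (a sketch; AllocateCUDAMemory/DeallocateCUDAMemory are the helpers from cuda_utils.h, the wrapper function itself is hypothetical):

// Free any previous device buffers before allocating ones sized for
// the new dataset; a no-op on first use when both pointers are null.
void ReallocGradientBuffers(score_t** gradients, score_t** hessians, size_t total_size) {
  if (*gradients != nullptr) {
    DeallocateCUDAMemory<score_t>(gradients, __FILE__, __LINE__);
    DeallocateCUDAMemory<score_t>(hessians, __FILE__, __LINE__);
  }
  AllocateCUDAMemory<score_t>(gradients, total_size, __FILE__, __LINE__);
  AllocateCUDAMemory<score_t>(hessians, total_size, __FILE__, __LINE__);
}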
@@ -137,17 +145,14 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
 #ifdef USE_CUDA_EXP
     }
 #endif  // USE_CUDA_EXP
-#ifndef USE_CUDA_EXP
-  }
-#else  // USE_CUDA_EXP
-  } else {
-    if (config_->device_type == std::string("cuda_exp")) {
-      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
-      AllocateCUDAMemory<score_t>(&gradients_pointer_, total_size, __FILE__, __LINE__);
-      AllocateCUDAMemory<score_t>(&hessians_pointer_, total_size, __FILE__, __LINE__);
-    }
+  } else if (config_->boosting == std::string("goss")) {
+    const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
+    gradients_.resize(total_size);
+    hessians_.resize(total_size);
+    gradients_pointer_ = gradients_.data();
+    hessians_pointer_ = hessians_.data();
   }
-#endif  // USE_CUDA_EXP
+
   // get max feature index
   max_feature_idx_ = train_data_->num_total_features() - 1;
   // get label index
@@ -440,23 +445,36 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
     Boosting();
     gradients = gradients_pointer_;
     hessians = hessians_pointer_;
-#ifndef USE_CUDA_EXP
-  }
-#else  // USE_CUDA_EXP
   } else {
-    if (config_->device_type == std::string("cuda_exp")) {
-      const size_t total_size = static_cast<size_t>(num_data_ * num_class_);
-      CopyFromHostToCUDADevice<score_t>(gradients_pointer_, gradients, total_size, __FILE__, __LINE__);
-      CopyFromHostToCUDADevice<score_t>(hessians_pointer_, hessians, total_size, __FILE__, __LINE__);
+    // use customized objective function
+    CHECK(objective_function_ == nullptr);
+    if (config_->boosting == std::string("goss")) {
+      // need to copy customized gradients when using GOSS
+      int64_t total_size = static_cast<int64_t>(num_data_) * num_tree_per_iteration_;
+      #pragma omp parallel for schedule(static)
+      for (int64_t i = 0; i < total_size; ++i) {
+        gradients_[i] = gradients[i];
+        hessians_[i] = hessians[i];
+      }
+      CHECK_EQ(gradients_pointer_, gradients_.data());
+      CHECK_EQ(hessians_pointer_, hessians_.data());
       gradients = gradients_pointer_;
       hessians = hessians_pointer_;
     }
   }
-#endif  // USE_CUDA_EXP
 
   // bagging logic
   Bagging(iter_);
 
+  if (gradients != nullptr && is_use_subset_ && bag_data_cnt_ < num_data_ && !boosting_on_gpu_ && config_->boosting != std::string("goss")) {
+    // allocate gradients_ and hessians_ for copy gradients for using data subset
+    int64_t total_size = static_cast<int64_t>(num_data_) * num_tree_per_iteration_;
+    gradients_.resize(total_size);
+    hessians_.resize(total_size);
+    gradients_pointer_ = gradients_.data();
+    hessians_pointer_ = hessians_.data();
+  }
+
   bool should_continue = false;
   for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
     const size_t offset = static_cast<size_t>(cur_tree_id) * num_data_;
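
With a customized objective, objective_function_ is null and the caller's gradients arrive as host pointers; for GOSS they must be copied into the internally owned buffers because sampling will later reorder them in place. A condensed, standalone sketch of that copy step (score_t is float in default builds; the function name is illustrative):

#include <cstdint>
#include <vector>

// Parallel element-wise copy of caller-provided gradients into
// internally owned, resizable buffers, as in the GOSS branch above.
void CopyCustomGradients(const float* gradients, const float* hessians,
                         int64_t total_size,
                         std::vector<float>* grad_buf, std::vector<float>* hess_buf) {
  grad_buf->resize(total_size);
  hess_buf->resize(total_size);
  #pragma omp parallel for schedule(static)
  for (int64_t i = 0; i < total_size; ++i) {
    (*grad_buf)[i] = gradients[i];
    (*hess_buf)[i] = hessians[i];
  }
}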
@@ -465,7 +483,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
     auto grad = gradients + offset;
     auto hess = hessians + offset;
     // need to copy gradients for bagging subset.
-    if (is_use_subset_ && bag_data_cnt_ < num_data_ && config_->device_type != std::string("cuda_exp")) {
+    if (is_use_subset_ && bag_data_cnt_ < num_data_ && !boosting_on_gpu_) {
       for (int i = 0; i < bag_data_cnt_; ++i) {
         gradients_pointer_[offset + i] = grad[bag_data_indices_[i]];
         hessians_pointer_[offset + i] = hess[bag_data_indices_[i]];
@@ -591,13 +609,12 @@ void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
 
 std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score) const {
 #ifdef USE_CUDA_EXP
-  const bool boosting_on_cuda = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
   const bool evaluation_on_cuda = metric->IsCUDAMetric();
-  if ((boosting_on_cuda && evaluation_on_cuda) || (!boosting_on_cuda && !evaluation_on_cuda)) {
+  if ((boosting_on_gpu_ && evaluation_on_cuda) || (!boosting_on_gpu_ && !evaluation_on_cuda)) {
 #endif  // USE_CUDA_EXP
     return metric->Eval(score, objective_function_);
 #ifdef USE_CUDA_EXP
-  } else if (boosting_on_cuda && !evaluation_on_cuda) {
+  } else if (boosting_on_gpu_ && !evaluation_on_cuda) {
     const size_t total_size = static_cast<size_t>(num_data_) * static_cast<size_t>(num_tree_per_iteration_);
     if (total_size > host_score_.size()) {
       host_score_.resize(total_size, 0.0f);
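
When boosting runs on the GPU but a metric is CPU-only, the scores have to be staged through host_score_ before metric->Eval() can read them. The general pattern as a sketch (assumes the CopyFromCUDADeviceToHost helper from cuda_utils.h and LightGBM's Metric interface; the wrapper function and buffer names are illustrative):

// Staging pattern: copy device scores into a reusable host buffer,
// grown on demand, then hand the host pointer to the CPU metric.
std::vector<double> EvalOnHost(const Metric* metric, const ObjectiveFunction* objective,
                               const double* cuda_score, size_t total_size,
                               std::vector<double>* host_buf) {
  if (total_size > host_buf->size()) {
    host_buf->resize(total_size, 0.0f);
  }
  CopyFromCUDADeviceToHost<double>(host_buf->data(), cuda_score, total_size, __FILE__, __LINE__);
  return metric->Eval(host_buf->data(), objective);
}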
@@ -804,17 +821,16 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
   }
   training_metrics_.shrink_to_fit();
 
-#ifdef USE_CUDA_EXP
-  const bool boosting_on_gpu = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
-#endif  // USE_CUDA_EXP
+  boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
+  tree_learner_->ResetBoostingOnGPU(boosting_on_gpu_);
 
   if (train_data != train_data_) {
     train_data_ = train_data;
     // not same training data, need reset score and others
     // create score tracker
 #ifdef USE_CUDA_EXP
     if (config_->device_type == std::string("cuda_exp")) {
-      train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu));
+      train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu_));
     } else {
 #endif  // USE_CUDA_EXP
     train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
@@ -834,9 +850,14 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
 
     // create buffer for gradients and hessians
     if (objective_function_ != nullptr) {
-      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
+      const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
 #ifdef USE_CUDA_EXP
-      if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu) {
+      if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu_) {
+        if (gradients_pointer_ != nullptr) {
+          CHECK_NOTNULL(hessians_pointer_);
+          DeallocateCUDAMemory<score_t>(&gradients_pointer_, __FILE__, __LINE__);
+          DeallocateCUDAMemory<score_t>(&hessians_pointer_, __FILE__, __LINE__);
+        }
         AllocateCUDAMemory<score_t>(&gradients_pointer_, total_size, __FILE__, __LINE__);
         AllocateCUDAMemory<score_t>(&hessians_pointer_, total_size, __FILE__, __LINE__);
       } else {
@@ -848,6 +869,12 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
 #ifdef USE_CUDA_EXP
       }
 #endif  // USE_CUDA_EXP
+    } else if (config_->boosting == std::string("goss")) {
+      const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
+      gradients_.resize(total_size);
+      hessians_.resize(total_size);
+      gradients_pointer_ = gradients_.data();
+      hessians_pointer_ = hessians_.data();
     }
 
     max_feature_idx_ = train_data_->num_total_features() - 1;
@@ -879,6 +906,10 @@ void GBDT::ResetConfig(const Config* config) {
   if (tree_learner_ != nullptr) {
     tree_learner_->ResetConfig(new_config.get());
   }
+
+  boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
+  tree_learner_->ResetBoostingOnGPU(boosting_on_gpu_);
+
   if (train_data_ != nullptr) {
     ResetBaggingConfig(new_config.get(), false);
   }
@@ -953,10 +984,16 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
       need_re_bagging_ = true;
 
       if (is_use_subset_ && bag_data_cnt_ < num_data_) {
-        if (objective_function_ == nullptr) {
-          size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
+        // resize gradient vectors to copy the customized gradients for goss or bagging with subset
+        if (objective_function_ != nullptr) {
+          const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
 #ifdef USE_CUDA_EXP
-          if (config_->device_type == std::string("cuda_exp") && objective_function_ != nullptr && objective_function_->IsCUDAObjective()) {
+          if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu_) {
+            if (gradients_pointer_ != nullptr) {
+              CHECK_NOTNULL(hessians_pointer_);
+              DeallocateCUDAMemory<score_t>(&gradients_pointer_, __FILE__, __LINE__);
+              DeallocateCUDAMemory<score_t>(&hessians_pointer_, __FILE__, __LINE__);
+            }
             AllocateCUDAMemory<score_t>(&gradients_pointer_, total_size, __FILE__, __LINE__);
             AllocateCUDAMemory<score_t>(&hessians_pointer_, total_size, __FILE__, __LINE__);
           } else {
@@ -968,6 +1005,12 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
 #ifdef USE_CUDA_EXP
           }
 #endif  // USE_CUDA_EXP
+        } else if (config_->boosting == std::string("goss")) {
+          const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
+          gradients_.resize(total_size);
+          hessians_.resize(total_size);
+          gradients_pointer_ = gradients_.data();
+          hessians_pointer_ = hessians_.data();
         }
       }
     } else {

src/boosting/gbdt.h

Lines changed: 2 additions & 0 deletions
@@ -504,6 +504,8 @@ class GBDT : public GBDTBase {
   score_t* gradients_pointer_;
   /*! \brief Pointer to hessian vector, can be on CPU or GPU */
   score_t* hessians_pointer_;
+  /*! \brief Whether boosting is done on GPU, used for cuda_exp */
+  bool boosting_on_gpu_;
 #ifdef USE_CUDA_EXP
   /*! \brief Buffer for scores when boosting is on GPU but evaluation is not, used only with cuda_exp */
   mutable std::vector<double> host_score_;

src/boosting/rf.hpp

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ class RF : public GBDT {
     auto hess = hessians + offset;
 
     // need to copy gradients for bagging subset.
-    if (is_use_subset_ && bag_data_cnt_ < num_data_) {
+    if (is_use_subset_ && bag_data_cnt_ < num_data_ && !boosting_on_gpu_) {
       for (int i = 0; i < bag_data_cnt_; ++i) {
         tmp_grad_[i] = grad[bag_data_indices_[i]];
         tmp_hess_[i] = hess[bag_data_indices_[i]];
