From 8487d0a2247acf01e6c19f351c4a0f1e63937372 Mon Sep 17 00:00:00 2001 From: Huan Zhang Date: Sun, 23 Apr 2017 04:28:52 -0700 Subject: [PATCH] Fix compilation problems with MSVC (#443) * Fix warnings when compiled with -pedantic * add -DBOOST_ALL_NO_LIB for windows build * fix some more MSVC warnings * Break OpenCL string literal to smaller pieces to avoid error C2026 of MSVC The string was longer than the limit of 16380 single-byte characters. This affects Visual Studio 2005 - 2015. Untested on VS 2017. --- CMakeLists.txt | 4 ++ src/treelearner/gpu_tree_learner.cpp | 93 ++++++++++++++-------------- src/treelearner/gpu_tree_learner.h | 15 +---- src/treelearner/ocl/histogram16.cl | 18 ++++++ src/treelearner/ocl/histogram256.cl | 19 +++++- src/treelearner/ocl/histogram64.cl | 19 +++++- 6 files changed, 106 insertions(+), 62 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d0c19ab5c54c..b39237da0474 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,6 +43,10 @@ if(USE_GPU) MESSAGE(STATUS "OpenCL include directory:" ${OpenCL_INCLUDE_DIRS}) find_package(Boost 1.56.0 COMPONENTS filesystem system REQUIRED) include_directories(${Boost_INCLUDE_DIRS}) + if (WIN32) + # disable autolinking in boost + ADD_DEFINITIONS(-DBOOST_ALL_NO_LIB) + endif() ADD_DEFINITIONS(-DUSE_GPU) endif(USE_GPU) diff --git a/src/treelearner/gpu_tree_learner.cpp b/src/treelearner/gpu_tree_learner.cpp index e4b732e35376..cbf223441454 100644 --- a/src/treelearner/gpu_tree_learner.cpp +++ b/src/treelearner/gpu_tree_learner.cpp @@ -104,7 +104,7 @@ int GPUTreeLearner::GetNumWorkgroupsPerFeature(data_size_t leaf_num_data) { // we roughly want 256 workgroups per device, and we have num_dense_feature4_ feature tuples. // also guarantee that there are at least 2K examples per workgroup double x = 256.0 / num_dense_feature4_; - int exp_workgroups_per_feature = ceil(log2(x)); + int exp_workgroups_per_feature = (int)ceil(log2(x)); double t = leaf_num_data / 1024.0; #if GPU_DEBUG >= 4 printf("Computing histogram for %d examples and (%d * %d) feature groups\n", leaf_num_data, dword_features_, num_dense_feature4_); @@ -191,7 +191,7 @@ void GPUTreeLearner::GPUHistogram(data_size_t leaf_num_data, bool use_all_featur } template -void GPUTreeLearner::WaitAndGetHistograms(HistogramBinEntry* histograms, const std::vector& is_feature_used) { +void GPUTreeLearner::WaitAndGetHistograms(HistogramBinEntry* histograms) { HistType* hist_outputs = (HistType*) host_histogram_outputs_; // when the output is ready, the computation is done histograms_wait_obj_.wait(); @@ -207,7 +207,7 @@ void GPUTreeLearner::WaitAndGetHistograms(HistogramBinEntry* histograms, const s for (int j = 0; j < bin_size; ++j) { old_histogram_array[j].sum_gradients = hist_outputs[i * device_bin_size_+ j].sum_gradients; old_histogram_array[j].sum_hessians = hist_outputs[i * device_bin_size_ + j].sum_hessians; - old_histogram_array[j].cnt = hist_outputs[i * device_bin_size_ + j].cnt; + old_histogram_array[j].cnt = (data_size_t)hist_outputs[i * device_bin_size_ + j].cnt; } } else { @@ -224,7 +224,7 @@ void GPUTreeLearner::WaitAndGetHistograms(HistogramBinEntry* histograms, const s } old_histogram_array[j].sum_gradients = sum_g; old_histogram_array[j].sum_hessians = sum_h; - old_histogram_array[j].cnt = cnt; + old_histogram_array[j].cnt = (data_size_t)cnt; } } } @@ -323,11 +323,12 @@ void GPUTreeLearner::AllocateGPUMemory() { device_histogram_outputs_ = boost::compute::buffer(ctx_, num_dense_feature4_ * dword_features_ * device_bin_size_ * hist_bin_entry_sz_, boost::compute::memory_object::write_only | boost::compute::memory_object::alloc_host_ptr, nullptr); // find the dense feature-groups and group then into Feature4 data structure (several feature-groups packed into 4 bytes) - int i, k, copied_feature4 = 0, dense_ind[dword_features_]; - for (i = 0, k = 0; i < num_feature_groups_; ++i) { + int k = 0, copied_feature4 = 0; + std::vector dense_dword_ind(dword_features_); + for (int i = 0; i < num_feature_groups_; ++i) { // looking for dword_features_ non-sparse feature-groups if (ordered_bins_[i] == nullptr) { - dense_ind[k] = i; + dense_dword_ind[k] = i; // decide if we need to redistribute the bin double t = device_bin_size_ / (double)train_data_->FeatureGroupNumBin(i); // multiplier must be a power of 2 @@ -345,7 +346,7 @@ void GPUTreeLearner::AllocateGPUMemory() { if (k == dword_features_) { k = 0; for (int j = 0; j < dword_features_; ++j) { - dense_feature_group_map_.push_back(dense_ind[j]); + dense_feature_group_map_.push_back(dense_dword_ind[j]); } copied_feature4++; } @@ -369,7 +370,7 @@ void GPUTreeLearner::AllocateGPUMemory() { } // building Feature4 bundles; each thread handles dword_features_ features #pragma omp parallel for schedule(static) - for (unsigned int i = 0; i < dense_feature_group_map_.size() / dword_features_; ++i) { + for (int i = 0; i < (int)(dense_feature_group_map_.size() / dword_features_); ++i) { int tid = omp_get_thread_num(); Feature4* host4 = host4_ptrs[tid]; auto dense_ind = dense_feature_group_map_.begin() + i * dword_features_; @@ -401,14 +402,14 @@ void GPUTreeLearner::AllocateGPUMemory() { *static_cast(bin_iters[6]), *static_cast(bin_iters[7])}; for (int j = 0; j < num_data_; ++j) { - host4[j].s0 = (iters[0].RawGet(j) * dev_bin_mult[0] + ((j+0) & (dev_bin_mult[0] - 1))) - |((iters[1].RawGet(j) * dev_bin_mult[1] + ((j+1) & (dev_bin_mult[1] - 1))) << 4); - host4[j].s1 = (iters[2].RawGet(j) * dev_bin_mult[2] + ((j+2) & (dev_bin_mult[2] - 1))) - |((iters[3].RawGet(j) * dev_bin_mult[3] + ((j+3) & (dev_bin_mult[3] - 1))) << 4); - host4[j].s2 = (iters[4].RawGet(j) * dev_bin_mult[4] + ((j+4) & (dev_bin_mult[4] - 1))) - |((iters[5].RawGet(j) * dev_bin_mult[5] + ((j+5) & (dev_bin_mult[5] - 1))) << 4); - host4[j].s3 = (iters[6].RawGet(j) * dev_bin_mult[6] + ((j+6) & (dev_bin_mult[6] - 1))) - |((iters[7].RawGet(j) * dev_bin_mult[7] + ((j+7) & (dev_bin_mult[7] - 1))) << 4); + host4[j].s[0] = (uint8_t)((iters[0].RawGet(j) * dev_bin_mult[0] + ((j+0) & (dev_bin_mult[0] - 1))) + |((iters[1].RawGet(j) * dev_bin_mult[1] + ((j+1) & (dev_bin_mult[1] - 1))) << 4)); + host4[j].s[1] = (uint8_t)((iters[2].RawGet(j) * dev_bin_mult[2] + ((j+2) & (dev_bin_mult[2] - 1))) + |((iters[3].RawGet(j) * dev_bin_mult[3] + ((j+3) & (dev_bin_mult[3] - 1))) << 4)); + host4[j].s[2] = (uint8_t)((iters[4].RawGet(j) * dev_bin_mult[4] + ((j+4) & (dev_bin_mult[4] - 1))) + |((iters[5].RawGet(j) * dev_bin_mult[5] + ((j+5) & (dev_bin_mult[5] - 1))) << 4)); + host4[j].s[3] = (uint8_t)((iters[6].RawGet(j) * dev_bin_mult[6] + ((j+6) & (dev_bin_mult[6] - 1))) + |((iters[7].RawGet(j) * dev_bin_mult[7] + ((j+7) & (dev_bin_mult[7] - 1))) << 4)); } } else if (dword_features_ == 4) { @@ -420,14 +421,14 @@ void GPUTreeLearner::AllocateGPUMemory() { // Dense bin DenseBinIterator iter = *static_cast*>(bin_iter); for (int j = 0; j < num_data_; ++j) { - host4[j].s[s_idx] = iter.RawGet(j) * dev_bin_mult[s_idx] + ((j+s_idx) & (dev_bin_mult[s_idx] - 1)); + host4[j].s[s_idx] = (uint8_t)(iter.RawGet(j) * dev_bin_mult[s_idx] + ((j+s_idx) & (dev_bin_mult[s_idx] - 1))); } } else if (dynamic_cast(bin_iter) != 0) { // Dense 4-bit bin Dense4bitsBinIterator iter = *static_cast(bin_iter); for (int j = 0; j < num_data_; ++j) { - host4[j].s[s_idx] = iter.RawGet(j) * dev_bin_mult[s_idx] + ((j+s_idx) & (dev_bin_mult[s_idx] - 1)); + host4[j].s[s_idx] = (uint8_t)(iter.RawGet(j) * dev_bin_mult[s_idx] + ((j+s_idx) & (dev_bin_mult[s_idx] - 1))); } } else { @@ -458,38 +459,38 @@ void GPUTreeLearner::AllocateGPUMemory() { #if GPU_DEBUG >= 1 printf("%d features left\n", k); #endif - for (i = 0; i < k; ++i) { + for (int i = 0; i < k; ++i) { if (dword_features_ == 8) { - BinIterator* bin_iter = train_data_->FeatureGroupIterator(dense_ind[i]); + BinIterator* bin_iter = train_data_->FeatureGroupIterator(dense_dword_ind[i]); if (dynamic_cast(bin_iter) != 0) { Dense4bitsBinIterator iter = *static_cast(bin_iter); #pragma omp parallel for schedule(static) for (int j = 0; j < num_data_; ++j) { - host4[j].s[i >> 1] |= ((iter.RawGet(j) * device_bin_mults_[copied_feature4 * dword_features_ + i] + host4[j].s[i >> 1] |= (uint8_t)((iter.RawGet(j) * device_bin_mults_[copied_feature4 * dword_features_ + i] + ((j+i) & (device_bin_mults_[copied_feature4 * dword_features_ + i] - 1))) << ((i & 1) << 2)); } } else { - Log::Fatal("GPU tree learner assumes that all bins are Dense4bitsBin when num_bin <= 16, but feature %d is not.", dense_ind[i]); + Log::Fatal("GPU tree learner assumes that all bins are Dense4bitsBin when num_bin <= 16, but feature %d is not.", dense_dword_ind[i]); } } else if (dword_features_ == 4) { - BinIterator* bin_iter = train_data_->FeatureGroupIterator(dense_ind[i]); + BinIterator* bin_iter = train_data_->FeatureGroupIterator(dense_dword_ind[i]); if (dynamic_cast*>(bin_iter) != 0) { DenseBinIterator iter = *static_cast*>(bin_iter); #pragma omp parallel for schedule(static) for (int j = 0; j < num_data_; ++j) { - host4[j].s[i] = iter.RawGet(j) * device_bin_mults_[copied_feature4 * dword_features_ + i] - + ((j+i) & (device_bin_mults_[copied_feature4 * dword_features_ + i] - 1)); + host4[j].s[i] = (uint8_t)(iter.RawGet(j) * device_bin_mults_[copied_feature4 * dword_features_ + i] + + ((j+i) & (device_bin_mults_[copied_feature4 * dword_features_ + i] - 1))); } } else if (dynamic_cast(bin_iter) != 0) { Dense4bitsBinIterator iter = *static_cast(bin_iter); #pragma omp parallel for schedule(static) for (int j = 0; j < num_data_; ++j) { - host4[j].s[i] = iter.RawGet(j) * device_bin_mults_[copied_feature4 * dword_features_ + i] - + ((j+i) & (device_bin_mults_[copied_feature4 * dword_features_ + i] - 1)); + host4[j].s[i] = (uint8_t)(iter.RawGet(j) * device_bin_mults_[copied_feature4 * dword_features_ + i] + + ((j+i) & (device_bin_mults_[copied_feature4 * dword_features_ + i] - 1))); } } else { @@ -504,18 +505,18 @@ void GPUTreeLearner::AllocateGPUMemory() { if (dword_features_ == 8) { #pragma omp parallel for schedule(static) for (int j = 0; j < num_data_; ++j) { - for (i = k; i < dword_features_; ++i) { + for (int i = k; i < dword_features_; ++i) { // fill this empty feature with some "random" value - host4[j].s[i >> 1] |= ((j & 0xf) << ((i & 1) << 2)); + host4[j].s[i >> 1] |= (uint8_t)((j & 0xf) << ((i & 1) << 2)); } } } else if (dword_features_ == 4) { #pragma omp parallel for schedule(static) for (int j = 0; j < num_data_; ++j) { - for (i = k; i < dword_features_; ++i) { + for (int i = k; i < dword_features_; ++i) { // fill this empty feature with some "random" value - host4[j].s[i] = j; + host4[j].s[i] = (uint8_t)j; } } } @@ -525,8 +526,8 @@ void GPUTreeLearner::AllocateGPUMemory() { #if GPU_DEBUG >= 1 printf("Last features copied to device\n"); #endif - for (i = 0; i < k; ++i) { - dense_feature_group_map_.push_back(dense_ind[i]); + for (int i = 0; i < k; ++i) { + dense_feature_group_map_.push_back(dense_dword_ind[i]); } } // deallocate pinned space for feature copying @@ -542,12 +543,12 @@ void GPUTreeLearner::AllocateGPUMemory() { end_time * 1e-3, sparse_feature_group_map_.size()); #if GPU_DEBUG >= 1 printf("Dense feature group list (size %lu): ", dense_feature_group_map_.size()); - for (i = 0; i < num_dense_feature_groups_; ++i) { + for (int i = 0; i < num_dense_feature_groups_; ++i) { printf("%d ", dense_feature_group_map_[i]); } printf("\n"); printf("Sparse feature group list (size %lu): ", sparse_feature_group_map_.size()); - for (i = 0; i < num_feature_groups_ - num_dense_feature_groups_; ++i) { + for (int i = 0; i < num_feature_groups_ - num_dense_feature_groups_; ++i) { printf("%d ", sparse_feature_group_map_[i]); } printf("\n"); @@ -584,10 +585,10 @@ void GPUTreeLearner::BuildGPUKernels() { } catch (boost::compute::opencl_error &e) { if (program.build_log().size() > 0) { - Log::Fatal("GPU program built failure:\n %s", program.build_log().c_str()); + Log::Fatal("GPU program built failure: %s\n %s", e.what(), program.build_log().c_str()); } else { - Log::Fatal("GPU program built failure, log unavailable"); + Log::Fatal("GPU program built failure: %s\nlog unavailable", e.what()); } } histogram_kernels_[i] = program.create_kernel(kernel_name_); @@ -599,10 +600,10 @@ void GPUTreeLearner::BuildGPUKernels() { } catch (boost::compute::opencl_error &e) { if (program.build_log().size() > 0) { - Log::Fatal("GPU program built failure:\n %s", program.build_log().c_str()); + Log::Fatal("GPU program built failure: %s\n %s", e.what(), program.build_log().c_str()); } else { - Log::Fatal("GPU program built failure, log unavailable"); + Log::Fatal("GPU program built failure: %s\nlog unavailable", e.what()); } } histogram_allfeats_kernels_[i] = program.create_kernel(kernel_name_); @@ -614,10 +615,10 @@ void GPUTreeLearner::BuildGPUKernels() { } catch (boost::compute::opencl_error &e) { if (program.build_log().size() > 0) { - Log::Fatal("GPU program built failure:\n %s", program.build_log().c_str()); + Log::Fatal("GPU program built failure: %s\n %s", e.what(), program.build_log().c_str()); } else { - Log::Fatal("GPU program built failure, log unavailable"); + Log::Fatal("GPU program built failure: %s\nlog unavailable", e.what()); } } histogram_fulldata_kernels_[i] = program.create_kernel(kernel_name_); @@ -979,11 +980,11 @@ void GPUTreeLearner::ConstructHistograms(const std::vector& is_feature_u if (is_gpu_used) { if (tree_config_->gpu_use_dp) { // use double precision - WaitAndGetHistograms(ptr_smaller_leaf_hist_data, is_feature_used); + WaitAndGetHistograms(ptr_smaller_leaf_hist_data); } else { // use single precision - WaitAndGetHistograms(ptr_smaller_leaf_hist_data, is_feature_used); + WaitAndGetHistograms(ptr_smaller_leaf_hist_data); } } @@ -1033,11 +1034,11 @@ void GPUTreeLearner::ConstructHistograms(const std::vector& is_feature_u if (is_gpu_used) { if (tree_config_->gpu_use_dp) { // use double precision - WaitAndGetHistograms(ptr_larger_leaf_hist_data, is_feature_used); + WaitAndGetHistograms(ptr_larger_leaf_hist_data); } else { // use single precision - WaitAndGetHistograms(ptr_larger_leaf_hist_data, is_feature_used); + WaitAndGetHistograms(ptr_larger_leaf_hist_data); } } } diff --git a/src/treelearner/gpu_tree_learner.h b/src/treelearner/gpu_tree_learner.h index 9b7dd9d6112b..ae33a87b195e 100644 --- a/src/treelearner/gpu_tree_learner.h +++ b/src/treelearner/gpu_tree_learner.h @@ -64,15 +64,7 @@ class GPUTreeLearner: public SerialTreeLearner { private: /*! \brief 4-byte feature tuple used by GPU kernels */ struct Feature4 { - union { - unsigned char s[4]; - struct { - unsigned char s0; - unsigned char s1; - unsigned char s2; - unsigned char s3; - }; - }; + uint8_t s[4]; }; /*! \brief Single precision histogram entiry for GPU */ @@ -123,10 +115,9 @@ class GPUTreeLearner: public SerialTreeLearner { /*! * \brief Wait for GPU kernel execution and read histogram * \param histograms Destination of histogram results from GPU. - * \param is_feature_used A predicate vector for enabling each feature */ template - void WaitAndGetHistograms(HistogramBinEntry* histograms, const std::vector& is_feature_used); + void WaitAndGetHistograms(HistogramBinEntry* histograms); /*! * \brief Construct GPU histogram asynchronously. @@ -173,7 +164,7 @@ class GPUTreeLearner: public SerialTreeLearner { const char *kernel64_src_ = #include "ocl/histogram64.cl" ; - /*! \brief GPU kernel for 64 bins */ + /*! \brief GPU kernel for 16 bins */ const char *kernel16_src_ = #include "ocl/histogram16.cl" ; diff --git a/src/treelearner/ocl/histogram16.cl b/src/treelearner/ocl/histogram16.cl index dc393f60c440..b98f45366ee3 100644 --- a/src/treelearner/ocl/histogram16.cl +++ b/src/treelearner/ocl/histogram16.cl @@ -149,6 +149,10 @@ inline void atomic_local_add_f(__local acc_type *addr, const float val) #endif } +/* Makes MSVC happy with long string literal +)"" +R""() +*/ // this function will be called by histogram16 // we have one sub-histogram of one feature in registers, and need to read others void within_kernel_reduction16x8(uchar8 feature_mask, @@ -204,6 +208,11 @@ void within_kernel_reduction16x8(uchar8 feature_mask, } +/* Makes MSVC happy with long string literal +)"" +R""() +*/ + __attribute__((reqd_work_group_size(LOCAL_SIZE_0, 1, 1))) #if USE_CONSTANT_BUF == 1 __kernel void histogram16(__global const uchar4* restrict feature_data_base, @@ -379,6 +388,10 @@ __kernel void histogram16(__global const uchar4* feature_data_base, #endif feature4 = feature_data[ind]; +/* Makes MSVC happy with long string literal +)"" +R""() +*/ // there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4 for (uint i = subglobal_tid; i < num_data; i += subglobal_size) { // prefetch the next iteration variables @@ -601,6 +614,11 @@ __kernel void histogram16(__global const uchar4* feature_data_base, feature4 = feature4_next; } barrier(CLK_LOCAL_MEM_FENCE); + +/* Makes MSVC happy with long string literal +)"" +R""() +*/ #if ENABLE_ALL_FEATURES == 0 // restore feature_mask diff --git a/src/treelearner/ocl/histogram256.cl b/src/treelearner/ocl/histogram256.cl index 44471849fe6f..8d750aea5561 100644 --- a/src/treelearner/ocl/histogram256.cl +++ b/src/treelearner/ocl/histogram256.cl @@ -125,6 +125,10 @@ inline void atomic_local_add_f(__local acc_type *addr, const float val) #endif } +/* Makes MSVC happy with long string literal +)"" +R""() +*/ // this function will be called by histogram256 // we have one sub-histogram of one feature in local memory, and need to read others void within_kernel_reduction256x4(uchar4 feature_mask, @@ -332,8 +336,10 @@ void within_kernel_reduction256x4(uchar4 feature_mask, #endif } -#define printf - +/* Makes MSVC happy with long string literal +)"" +R""() +*/ __attribute__((reqd_work_group_size(LOCAL_SIZE_0, 1, 1))) #if USE_CONSTANT_BUF == 1 __kernel void histogram256(__global const uchar4* restrict feature_data_base, @@ -453,6 +459,11 @@ __kernel void histogram256(__global const uchar4* feature_data_base, acc_type s1_stat1 = 0.0f, s1_stat2 = 0.0f; acc_type s0_stat1 = 0.0f, s0_stat2 = 0.0f; + +/* Makes MSVC happy with long string literal +)"" +R""() +*/ // there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4 for (uint i = subglobal_tid; i < num_data; i += subglobal_size) { // prefetch the next iteration variables @@ -659,6 +670,10 @@ __kernel void histogram256(__global const uchar4* feature_data_base, #endif barrier(CLK_LOCAL_MEM_FENCE); +/* Makes MSVC happy with long string literal +)"" +R""() +*/ #if ENABLE_ALL_FEATURES == 0 // restore feature_mask feature_mask = feature_masks[group_feature]; diff --git a/src/treelearner/ocl/histogram64.cl b/src/treelearner/ocl/histogram64.cl index 661f764d4b63..2e8d03b7be93 100644 --- a/src/treelearner/ocl/histogram64.cl +++ b/src/treelearner/ocl/histogram64.cl @@ -142,6 +142,11 @@ inline void atomic_local_add_f(__local acc_type *addr, const float val) #endif } + +/* Makes MSVC happy with long string literal +)"" +R""() +*/ // this function will be called by histogram64 // we have one sub-histogram of one feature in registers, and need to read others void within_kernel_reduction64x4(uchar4 feature_mask, @@ -199,8 +204,10 @@ void within_kernel_reduction64x4(uchar4 feature_mask, } } -#define printf - +/* Makes MSVC happy with long string literal +)"" +R""() +*/ __attribute__((reqd_work_group_size(LOCAL_SIZE_0, 1, 1))) #if USE_CONSTANT_BUF == 1 __kernel void histogram64(__global const uchar4* restrict feature_data_base, @@ -360,6 +367,10 @@ __kernel void histogram64(__global const uchar4* feature_data_base, acc_type s1_stat1 = 0.0f, s1_stat2 = 0.0f; acc_type s0_stat1 = 0.0f, s0_stat2 = 0.0f; +/* Makes MSVC happy with long string literal +)"" +R""() +*/ // there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4 for (uint i = subglobal_tid; i < num_data; i += subglobal_size) { // prefetch the next iteration variables @@ -567,6 +578,10 @@ __kernel void histogram64(__global const uchar4* feature_data_base, #endif barrier(CLK_LOCAL_MEM_FENCE); +/* Makes MSVC happy with long string literal +)"" +R""() +*/ #if ENABLE_ALL_FEATURES == 0 // restore feature_mask feature_mask = feature_masks[group_feature];