Skip to content

Commit

Permalink
[CUDA] Add rmse metric for new CUDA version (microsoft#5611)
Browse files Browse the repository at this point in the history
* add rmse metric for new cuda version

* add Init for CUDAMetricInterface

* fix lint errors
  • Loading branch information
shiyu1994 authored Dec 2, 2022
1 parent 38a1f58 commit f0cfbff
Show file tree
Hide file tree
Showing 17 changed files with 230 additions and 25 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,8 @@ endif()
if(USE_CUDA_EXP)
src/boosting/cuda/*.cpp
src/boosting/cuda/*.cu
src/metric/cuda/*.cpp
src/metric/cuda/*.cu
src/objective/cuda/*.cpp
src/objective/cuda/*.cu
src/treelearner/cuda/*.cpp
Expand Down
41 changes: 41 additions & 0 deletions include/LightGBM/cuda/cuda_metric.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*!
 * Copyright (c) 2022 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/

#ifndef LIGHTGBM_CUDA_CUDA_METRIC_HPP_
#define LIGHTGBM_CUDA_CUDA_METRIC_HPP_

#ifdef USE_CUDA_EXP

#include <LightGBM/metric.h>

namespace LightGBM {

template <typename HOST_METRIC>
class CUDAMetricInterface: public HOST_METRIC {
 public:
  // Forwards the configuration to the host metric; the device pointers stay
  // null until Init() is called with CUDA-resident metadata.
  explicit CUDAMetricInterface(const Config& config)
      : HOST_METRIC(config), cuda_labels_(nullptr), cuda_weights_(nullptr) {}

  // Runs the host-side initialization first, then caches the device-side
  // label and weight pointers from the CUDA metadata.
  void Init(const Metadata& metadata, data_size_t num_data) override {
    HOST_METRIC::Init(metadata, num_data);
    const auto* cuda_metadata = metadata.cuda_metadata();
    cuda_labels_ = cuda_metadata->cuda_label();
    cuda_weights_ = cuda_metadata->cuda_weights();
  }

  // Identifies this metric as a CUDA implementation to the boosting driver.
  bool IsCUDAMetric() const { return true; }

 protected:
  // Device pointer to the label array; null before Init().
  const label_t* cuda_labels_;
  // Device pointer to per-row weights; callers treat null as "unweighted"
  // (see the cuda_weights_ == nullptr checks in the regression metric).
  const label_t* cuda_weights_;
};

} // namespace LightGBM

#endif // USE_CUDA_EXP

#endif // LIGHTGBM_CUDA_CUDA_METRIC_HPP_
12 changes: 3 additions & 9 deletions include/LightGBM/cuda/cuda_objective_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,8 @@ class CUDAObjectiveInterface: public HOST_OBJECTIVE {
cuda_weights_ = metadata.cuda_metadata()->cuda_weights();
}

virtual void ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const {
LaunchConvertOutputCUDAKernel(num_data, input, output);
}

std::function<void(data_size_t, const double*, double*)> GetCUDAConvertOutputFunc() const override {
return [this] (data_size_t num_data, const double* input, double* output) {
ConvertOutputCUDA(num_data, input, output);
};
virtual const double* ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const {
return LaunchConvertOutputCUDAKernel(num_data, input, output);
}

double BoostFromScore(int class_id) const override {
Expand Down Expand Up @@ -67,7 +61,7 @@ class CUDAObjectiveInterface: public HOST_OBJECTIVE {
return HOST_OBJECTIVE::BoostFromScore(class_id);
}

virtual void LaunchConvertOutputCUDAKernel(const data_size_t /*num_data*/, const double* /*input*/, double* /*output*/) const {}
virtual const double* LaunchConvertOutputCUDAKernel(const data_size_t /*num_data*/, const double* input, double* /*output*/) const { return input; }

virtual void LaunchRenewTreeOutputCUDAKernel(
const double* /*score*/, const data_size_t* /*data_indices_in_leaf*/, const data_size_t* /*num_data_in_leaf*/,
Expand Down
6 changes: 3 additions & 3 deletions include/LightGBM/objective_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ class ObjectiveFunction {

#ifdef USE_CUDA_EXP
/*!
* \brief Get output convert function for CUDA version
* \brief Convert output for CUDA version
*/
virtual std::function<void(data_size_t, const double*, double*)> GetCUDAConvertOutputFunc() const {
return [] (data_size_t, const double*, double*) {};
const double* ConvertOutputCUDA(data_size_t /*num_data*/, const double* input, double* /*output*/) const {
return input;
}
#endif // USE_CUDA_EXP
};
Expand Down
1 change: 1 addition & 0 deletions src/cuda/cuda_algorithms.cu
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ void ShuffleReduceSumGlobal(const VAL_T* values, size_t n, REDUCE_T* block_buffe
}

template void ShuffleReduceSumGlobal<label_t, double>(const label_t* values, size_t n, double* block_buffer);
template void ShuffleReduceSumGlobal<double, double>(const double* values, size_t n, double* block_buffer);

template <typename VAL_T, typename REDUCE_T>
__global__ void ShuffleReduceMinGlobalKernel(const VAL_T* values, const data_size_t num_value, REDUCE_T* block_buffer) {
Expand Down
43 changes: 43 additions & 0 deletions src/metric/cuda/cuda_regression_metric.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*!
* Copyright (c) 2022 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/

#ifdef USE_CUDA_EXP

#include <vector>

#include "cuda_regression_metric.hpp"

namespace LightGBM {

template <typename HOST_METRIC, typename CUDA_METRIC>
void CUDARegressionMetricInterface<HOST_METRIC, CUDA_METRIC>::Init(const Metadata& metadata, data_size_t num_data) {
CUDAMetricInterface<HOST_METRIC>::Init(metadata, num_data);
const int max_num_reduce_blocks = (this->num_data_ + NUM_DATA_PER_EVAL_THREAD - 1) / NUM_DATA_PER_EVAL_THREAD;
if (this->cuda_weights_ == nullptr) {
reduce_block_buffer_.Resize(max_num_reduce_blocks);
} else {
reduce_block_buffer_.Resize(max_num_reduce_blocks * 2);
}
const int max_num_reduce_blocks_inner = (max_num_reduce_blocks + NUM_DATA_PER_EVAL_THREAD - 1) / NUM_DATA_PER_EVAL_THREAD;
if (this->cuda_weights_ == nullptr) {
reduce_block_buffer_inner_.Resize(max_num_reduce_blocks_inner);
} else {
reduce_block_buffer_inner_.Resize(max_num_reduce_blocks_inner * 2);
}
}

template <typename HOST_METRIC, typename CUDA_METRIC>
std::vector<double> CUDARegressionMetricInterface<HOST_METRIC, CUDA_METRIC>::Eval(const double* score, const ObjectiveFunction* objective) const {
  // Let the objective transform raw scores into prediction space on the
  // device; an identity objective returns `score` and ignores the buffer.
  const double* converted_score = objective->ConvertOutputCUDA(this->num_data_, score, score_convert_buffer_.RawData());
  // Reduce the point-wise losses on the device and wrap the scalar result.
  return std::vector<double>{LaunchEvalKernel(converted_score)};
}

CUDARMSEMetric::CUDARMSEMetric(const Config& config): CUDARegressionMetricInterface<RMSEMetric, CUDARMSEMetric>(config) {}

} // namespace LightGBM

#endif // USE_CUDA_EXP
61 changes: 61 additions & 0 deletions src/metric/cuda/cuda_regression_metric.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*!
* Copyright (c) 2022 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/

#ifdef USE_CUDA_EXP

#include <LightGBM/cuda/cuda_algorithms.hpp>

#include "cuda_regression_metric.hpp"

namespace LightGBM {

// Computes per-block partial sums of the point-wise losses (and, when
// USE_WEIGHTS, of the weights) for a later global reduction.
// Launch layout: 1-D grid, blockDim.x == NUM_DATA_PER_EVAL_THREAD; the host
// reads loss sums from reduce_block_buffer[0 .. gridDim.x) and weight sums
// from reduce_block_buffer[gridDim.x .. 2 * gridDim.x).
template <typename CUDA_METRIC, bool USE_WEIGHTS>
__global__ void EvalKernel(const data_size_t num_data, const label_t* labels, const label_t* weights,
  const double* scores, double* reduce_block_buffer) {
  __shared__ double shared_mem_buffer[32];
  const data_size_t index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
  double point_metric = 0.0;
  if (index < num_data) {
    point_metric = CUDA_METRIC::MetricOnPointCUDA(labels[index], scores[index]);
    if (USE_WEIGHTS) {
      // Weight each point's loss: the host divides the total by the weight
      // sum (AverageLoss), so the numerator must be the weighted loss sum.
      point_metric *= static_cast<double>(weights[index]);
    }
  }
  // ShuffleReduceSum is a block-level reduction through shared memory, so
  // every thread of the block must reach it (tail threads contribute 0.0);
  // it must never sit inside a divergent branch.
  const double block_sum_point_metric = ShuffleReduceSum<double>(point_metric, shared_mem_buffer, NUM_DATA_PER_EVAL_THREAD);
  if (threadIdx.x == 0) {
    reduce_block_buffer[blockIdx.x] = block_sum_point_metric;
  }
  if (USE_WEIGHTS) {
    double weight = 0.0;
    if (index < num_data) {
      weight = static_cast<double>(weights[index]);
    }
    const double block_sum_weight = ShuffleReduceSum<double>(weight, shared_mem_buffer, NUM_DATA_PER_EVAL_THREAD);
    if (threadIdx.x == 0) {
      // Second half of the buffer starts at gridDim.x (== number of blocks,
      // matching the host-side read offset), NOT blockDim.x.
      reduce_block_buffer[blockIdx.x + gridDim.x] = block_sum_weight;
    }
  }
}

// Runs the two-stage device reduction over point losses and returns the
// averaged metric value computed by the host metric's AverageLoss.
template <typename HOST_METRIC, typename CUDA_METRIC>
double CUDARegressionMetricInterface<HOST_METRIC, CUDA_METRIC>::LaunchEvalKernel(const double* score) const {
  // Ceil-divide so the grid covers every data point.
  const int num_blocks = (this->num_data_ + NUM_DATA_PER_EVAL_THREAD - 1) / NUM_DATA_PER_EVAL_THREAD;
  const bool use_weights = (this->cuda_weights_ != nullptr);
  if (use_weights) {
    EvalKernel<CUDA_METRIC, true><<<num_blocks, NUM_DATA_PER_EVAL_THREAD>>>(
      this->num_data_, this->cuda_labels_, this->cuda_weights_, score, reduce_block_buffer_.RawData());
  } else {
    EvalKernel<CUDA_METRIC, false><<<num_blocks, NUM_DATA_PER_EVAL_THREAD>>>(
      this->num_data_, this->cuda_labels_, this->cuda_weights_, score, reduce_block_buffer_.RawData());
  }
  // Second stage: fold the per-block loss sums into one scalar on the
  // device, then copy that scalar back to the host.
  ShuffleReduceSumGlobal<double, double>(reduce_block_buffer_.RawData(), num_blocks, reduce_block_buffer_inner_.RawData());
  double sum_loss = 0.0;
  CopyFromCUDADeviceToHost<double>(&sum_loss, reduce_block_buffer_inner_.RawData(), 1, __FILE__, __LINE__);
  // Unweighted data: every row counts as 1, so the normalizer is num_data_.
  double sum_weight = static_cast<double>(this->num_data_);
  if (use_weights) {
    // The per-block weight sums live in the second half of the buffer.
    ShuffleReduceSumGlobal<double, double>(reduce_block_buffer_.RawData() + num_blocks, num_blocks, reduce_block_buffer_inner_.RawData());
    CopyFromCUDADeviceToHost<double>(&sum_weight, reduce_block_buffer_inner_.RawData(), 1, __FILE__, __LINE__);
  }
  return this->AverageLoss(sum_loss, sum_weight);
}

template double CUDARegressionMetricInterface<RMSEMetric, CUDARMSEMetric>::LaunchEvalKernel(const double* score) const;

} // namespace LightGBM

#endif // USE_CUDA_EXP
57 changes: 57 additions & 0 deletions src/metric/cuda/cuda_regression_metric.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*!
* Copyright (c) 2022 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/

#ifndef LIGHTGBM_METRIC_CUDA_CUDA_REGRESSION_METRIC_HPP_
#define LIGHTGBM_METRIC_CUDA_CUDA_REGRESSION_METRIC_HPP_

#ifdef USE_CUDA_EXP

#include <LightGBM/cuda/cuda_metric.hpp>
#include <LightGBM/cuda/cuda_utils.h>

#include <vector>

#include "../regression_metric.hpp"

#define NUM_DATA_PER_EVAL_THREAD (1024)

namespace LightGBM {

// Shared base for CUDA regression metrics. HOST_METRIC supplies the CPU-side
// metric contract (Init/AverageLoss); CUDA_METRIC is the concrete derived
// class providing a static device function MetricOnPointCUDA (CRTP-style).
template <typename HOST_METRIC, typename CUDA_METRIC>
class CUDARegressionMetricInterface: public CUDAMetricInterface<HOST_METRIC> {
 public:
  explicit CUDARegressionMetricInterface(const Config& config): CUDAMetricInterface<HOST_METRIC>(config) {}

  virtual ~CUDARegressionMetricInterface() {}

  // Caches device pointers and allocates the reduction buffers.
  void Init(const Metadata& metadata, data_size_t num_data) override;

  // Evaluates the metric on the device; `score` is a device pointer and the
  // objective may transform it (ConvertOutputCUDA) before the reduction.
  std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override;

 protected:
  // Launches the two-stage reduction over converted scores and returns the
  // averaged loss value.
  double LaunchEvalKernel(const double* score_convert) const;

  // Device buffer handed to the objective's ConvertOutputCUDA as output.
  CUDAVector<double> score_convert_buffer_;
  // Per-block partial sums (loss, and weights in the second half when used).
  CUDAVector<double> reduce_block_buffer_;
  // Scratch for the second (global) reduction stage.
  CUDAVector<double> reduce_block_buffer_inner_;
};

// CUDA implementation of the RMSE metric.
class CUDARMSEMetric: public CUDARegressionMetricInterface<RMSEMetric, CUDARMSEMetric> {
 public:
  explicit CUDARMSEMetric(const Config& config);

  virtual ~CUDARMSEMetric() {}

  // Point-wise loss for RMSE: the SQUARED residual (this should match the
  // host RMSEMetric's per-point loss — verify against regression_metric.hpp).
  // The reduction averages these values and the final sqrt is applied on the
  // host; returning the raw difference would let positive and negative
  // residuals cancel and would not compute RMSE.
  __device__ static double MetricOnPointCUDA(label_t label, double score) {
    const double diff = score - static_cast<double>(label);
    return diff * diff;
  }
};

} // namespace LightGBM

#endif // USE_CUDA_EXP

#endif // LIGHTGBM_METRIC_CUDA_CUDA_REGRESSION_METRIC_HPP_
5 changes: 3 additions & 2 deletions src/metric/metric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "regression_metric.hpp"
#include "xentropy_metric.hpp"

#include "cuda/cuda_regression_metric.hpp"

namespace LightGBM {

Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
Expand All @@ -20,8 +22,7 @@ Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
Log::Warning("Metric l2 is not implemented in cuda_exp version. Fall back to evaluation on CPU.");
return new L2Metric(config);
} else if (type == std::string("rmse")) {
Log::Warning("Metric rmse is not implemented in cuda_exp version. Fall back to evaluation on CPU.");
return new RMSEMetric(config);
return new CUDARMSEMetric(config);
} else if (type == std::string("l1")) {
Log::Warning("Metric l1 is not implemented in cuda_exp version. Fall back to evaluation on CPU.");
return new L1Metric(config);
Expand Down
2 changes: 1 addition & 1 deletion src/metric/regression_metric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class RegressionMetric: public Metric {
inline static void CheckLabel(label_t) {
}

private:
protected:
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Pointer of label */
Expand Down
3 changes: 2 additions & 1 deletion src/objective/cuda/cuda_binary_objective.cu
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,10 @@ __global__ void ConvertOutputCUDAKernel_BinaryLogloss(const double sigmoid, cons
}
}

void CUDABinaryLogloss::LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const {
// Launches the device conversion of raw scores into `output` using the
// configured sigmoid parameter, and returns `output` as the converted array.
const double* CUDABinaryLogloss::LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const {
  // Ceil-divide so the grid covers every data point.
  const int block_count = (num_data + GET_GRADIENTS_BLOCK_SIZE_BINARY - 1) / GET_GRADIENTS_BLOCK_SIZE_BINARY;
  ConvertOutputCUDAKernel_BinaryLogloss<<<block_count, GET_GRADIENTS_BLOCK_SIZE_BINARY>>>(sigmoid_, num_data, input, output);
  return output;
}

__global__ void ResetOVACUDALabelKernel(
Expand Down
2 changes: 1 addition & 1 deletion src/objective/cuda/cuda_binary_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class CUDABinaryLogloss : public CUDAObjectiveInterface<BinaryLogloss> {

double LaunchCalcInitScoreKernel(const int class_id) const override;

void LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const override;
const double* LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const override;

void LaunchResetOVACUDALabelKernel() const;

Expand Down
3 changes: 2 additions & 1 deletion src/objective/cuda/cuda_multiclass_objective.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ void CUDAMulticlassOVA::GetGradients(const double* score, score_t* gradients, sc
}
}

void CUDAMulticlassOVA::ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const {
// Converts each class's score column independently through its own binary
// objective; scores are laid out class-major (class k occupies
// [k * num_data, (k + 1) * num_data)). Returns `output`.
const double* CUDAMulticlassOVA::ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const {
  for (int class_index = 0; class_index < num_class_; ++class_index) {
    cuda_binary_loss_[class_index]->ConvertOutputCUDA(num_data, input + class_index * num_data, output + class_index * num_data);
  }
  return output;
}


Expand Down
3 changes: 2 additions & 1 deletion src/objective/cuda/cuda_multiclass_objective.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ __global__ void ConvertOutputCUDAKernel_MulticlassSoftmax(
}
}

void CUDAMulticlassSoftmax::LaunchConvertOutputCUDAKernel(
// Launches the softmax conversion kernel over `num_data` rows and returns
// `output` as the converted array.
const double* CUDAMulticlassSoftmax::LaunchConvertOutputCUDAKernel(
  const data_size_t num_data, const double* input, double* output) const {
  // Grid size must be derived from the `num_data` parameter actually being
  // converted — the kernel receives `num_data` as its bound. The previous
  // code sized the grid from the member num_data_ (training-set size), which
  // under-launches whenever this is called with a larger num_data.
  // NOTE(review): cuda_softmax_buffer_'s capacity vs. num_data should be
  // confirmed where the buffer is allocated.
  const int num_blocks = (num_data + GET_GRADIENTS_BLOCK_SIZE_MULTICLASS - 1) / GET_GRADIENTS_BLOCK_SIZE_MULTICLASS;
  ConvertOutputCUDAKernel_MulticlassSoftmax<<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_MULTICLASS>>>(
    num_class_, num_data, input, cuda_softmax_buffer_.RawData(), output);
  return output;
}

} // namespace LightGBM
Expand Down
4 changes: 2 additions & 2 deletions src/objective/cuda/cuda_multiclass_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class CUDAMulticlassSoftmax: public CUDAObjectiveInterface<MulticlassSoftmax> {
private:
void LaunchGetGradientsKernel(const double* scores, score_t* gradients, score_t* hessians) const;

void LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const;
const double* LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const;

// CUDA memory, held by this object
CUDAVector<double> cuda_softmax_buffer_;
Expand All @@ -51,7 +51,7 @@ class CUDAMulticlassOVA: public CUDAObjectiveInterface<MulticlassOVA> {

void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override;

void ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const override;
const double* ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const override;

double BoostFromScore(int class_id) const override {
return cuda_binary_loss_[class_id]->BoostFromScore(0);
Expand Down
6 changes: 4 additions & 2 deletions src/objective/cuda/cuda_regression_objective.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ __global__ void ConvertOutputCUDAKernel_Regression(const bool sqrt, const data_s
}
}

void CUDARegressionL2loss::LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const {
// Launches the regression output-conversion kernel (sqrt_ selects whether
// the sqrt transform is undone) and returns `output`.
const double* CUDARegressionL2loss::LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const {
  // One thread per data point, rounded up to whole blocks.
  const int block_count = (num_data + GET_GRADIENTS_BLOCK_SIZE_REGRESSION - 1) / GET_GRADIENTS_BLOCK_SIZE_REGRESSION;
  ConvertOutputCUDAKernel_Regression<<<block_count, GET_GRADIENTS_BLOCK_SIZE_REGRESSION>>>(sqrt_, num_data, input, output);
  return output;
}

template <bool USE_WEIGHT>
Expand Down Expand Up @@ -339,9 +340,10 @@ __global__ void ConvertOutputCUDAKernel_Regression_Poisson(const data_size_t num
}
}

void CUDARegressionPoissonLoss::LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const {
// Launches the Poisson output-conversion kernel and returns `output`.
const double* CUDARegressionPoissonLoss::LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const {
  // One thread per data point, rounded up to whole blocks.
  const int block_count = (num_data + GET_GRADIENTS_BLOCK_SIZE_REGRESSION - 1) / GET_GRADIENTS_BLOCK_SIZE_REGRESSION;
  ConvertOutputCUDAKernel_Regression_Poisson<<<block_count, GET_GRADIENTS_BLOCK_SIZE_REGRESSION>>>(num_data, input, output);
  return output;
}


Expand Down
4 changes: 2 additions & 2 deletions src/objective/cuda/cuda_regression_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class CUDARegressionL2loss : public CUDARegressionObjectiveInterface<RegressionL
protected:
void LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const override;

void LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const override;
const double* LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const override;
};


Expand Down Expand Up @@ -121,7 +121,7 @@ class CUDARegressionPoissonLoss : public CUDARegressionObjectiveInterface<Regres
protected:
void LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const override;

void LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const override;
const double* LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const override;

double LaunchCalcInitScoreKernel(const int class_id) const override;

Expand Down

0 comments on commit f0cfbff

Please sign in to comment.