Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge feature into main branch #1320

Merged
merged 11 commits into from
Feb 9, 2025
Prev Previous commit
Next Next commit
Add class boosting::NonDecomposableStatisticsState.
  • Loading branch information
michael-rapp committed Feb 9, 2025
commit e7ea32824ec302a3c0da68a2a1995995c604ba3d
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@

/*
* @author Michael Rapp (michael.rapp.ml@gmail.com)
*/
#pragma once

#include "statistics_state.hpp"

namespace boosting {

/**
* Represents the current state of a sequential boosting process, which uses a non-decomposable loss function for
* calculating gradients and Hessians, and allows to update it.
*
* @tparam OutputMatrix The type of the matrix that provides access to the ground truth of the training examples
* @tparam StatisticMatrix The type of the matrix that stores the gradients and Hessians
* @tparam ScoreMatrix The type of the matrices that are used to store predicted scores
* @tparam LossFunction The type of the loss function that is used to calculate gradients and Hessians
*/
template<typename OutputMatrix, typename StatisticMatrix, typename ScoreMatrix, typename LossFunction>
class NonDecomposableStatisticsState final
: public AbstractStatisticsState<OutputMatrix, StatisticMatrix, ScoreMatrix, LossFunction> {
public:

/**
* @param outputMatrix A reference to an object of template type `OutputMatrix` that
* provides access to the ground truth of the training examples
* @param statisticMatrixPtr An unique pointer to an object of template type `StatisticMatrix`
* that provides access to the gradients and Hessians
* @param scoreMatrixPtr An unique pointer to an object of template type `ScoreMatrix` that
* stores the currently predicted scores
* @param lossFunctionPtr An unique pointer to an object of template type `LossFunction` that
* implements the loss function that should be used for calculating
* gradients and Hessians
*/
NonDecomposableStatisticsState(const OutputMatrix& outputMatrix,
std::unique_ptr<StatisticMatrix> statisticMatrixPtr,
std::unique_ptr<ScoreMatrix> scoreMatrixPtr,
std::unique_ptr<LossFunction> lossFunctionPtr)
: AbstractStatisticsState<OutputMatrix, StatisticMatrix, ScoreMatrix, LossFunction>(
outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr),
std::move(lossFunctionPtr)) {}

void updateStatistics(uint32 statisticIndex, CompleteIndexVector::const_iterator indicesBegin,
CompleteIndexVector::const_iterator indicesEnd) override {
this->lossFunctionPtr->updateNonDecomposableStatistics(statisticIndex, this->outputMatrix,
this->scoreMatrixPtr->getView(),
this->statisticMatrixPtr->getView());
}

void updateStatistics(uint32 statisticIndex, PartialIndexVector::const_iterator indicesBegin,
PartialIndexVector::const_iterator indicesEnd) override {
this->lossFunctionPtr->updateNonDecomposableStatistics(statisticIndex, this->outputMatrix,
this->scoreMatrixPtr->getView(),
this->statisticMatrixPtr->getView());
}
};

}