Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add template arguments to data structures storing statistics #1263

Merged
merged 4 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add template argument to class DenseNonDecomposableStatisticView.
  • Loading branch information
michael-rapp committed Jan 20, 2025
commit 486e665cd7e79e5451c455bac42b52156ea27010
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ namespace boosting {
* gradients and Hessians to be added to this vector
* @param row The index of the row to be added to this vector
*/
void add(const DenseNonDecomposableStatisticView& view, uint32 row);
void add(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row);

/**
* Adds all gradients and Hessians in a single row of a `DenseNonDecomposableStatisticView` to this vector.
Expand All @@ -175,7 +175,7 @@ namespace boosting {
* @param row The index of the row to be added to this vector
* @param weight The weight, the gradients and Hessians should be multiplied by
*/
void add(const DenseNonDecomposableStatisticView& view, uint32 row, StatisticType weight);
void add(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, StatisticType weight);

/**
* Removes all gradients and Hessians in a single row of a `DenseNonDecomposableStatisticView` from this
Expand All @@ -185,7 +185,7 @@ namespace boosting {
* gradients and Hessians to be removed from this vector
* @param row The index of the row to be removed from this vector
*/
void remove(const DenseNonDecomposableStatisticView& view, uint32 row);
void remove(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row);

/**
* Removes all gradients and Hessians in a single row of a `DenseNonDecomposableStatisticView` from this
Expand All @@ -196,7 +196,7 @@ namespace boosting {
* @param row The index of the row to be removed from this vector
* @param weight The weight, the gradients and Hessians should be multiplied by
*/
void remove(const DenseNonDecomposableStatisticView& view, uint32 row, StatisticType weight);
void remove(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, StatisticType weight);

/**
* Adds certain gradients and Hessians in another vector, whose positions are given as a
Expand All @@ -207,7 +207,7 @@ namespace boosting {
* @param row The index of the row to be added to this vector
* @param indices A reference to a `CompleteIndexVector` that provides access to the indices
*/
void addToSubset(const DenseNonDecomposableStatisticView& view, uint32 row,
void addToSubset(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row,
const CompleteIndexVector& indices);

/**
Expand All @@ -219,7 +219,7 @@ namespace boosting {
* @param row The index of the row to be added to this vector
* @param indices A reference to a `PartialIndexVector` that provides access to the indices
*/
void addToSubset(const DenseNonDecomposableStatisticView& view, uint32 row,
void addToSubset(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row,
const PartialIndexVector& indices);

/**
Expand All @@ -233,7 +233,7 @@ namespace boosting {
* @param indices A reference to a `CompleteIndexVector` that provides access to the indices
* @param weight The weight, the gradients and Hessians should be multiplied by
*/
void addToSubset(const DenseNonDecomposableStatisticView& view, uint32 row,
void addToSubset(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row,
const CompleteIndexVector& indices, StatisticType weight);

/**
Expand All @@ -247,7 +247,7 @@ namespace boosting {
* @param indices A reference to a `PartialIndexVector` that provides access to the indices
* @param weight The weight, the gradients and Hessians should be multiplied by
*/
void addToSubset(const DenseNonDecomposableStatisticView& view, uint32 row,
void addToSubset(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row,
const PartialIndexVector& indices, StatisticType weight);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ namespace boosting {
/**
* Implements row-wise read and write access to the gradients and Hessians that have been calculated using a
* non-decomposable loss function and are stored in pre-allocated C-contiguous arrays.
*
* @tparam StatisticType The type of the gradients and Hessians
*/
template<typename StatisticType>
class MLRLBOOSTING_API DenseNonDecomposableStatisticView
: public CompositeMatrix<AllocatedCContiguousView<float64>, AllocatedCContiguousView<float64>> {
: public CompositeMatrix<AllocatedCContiguousView<StatisticType>, AllocatedCContiguousView<StatisticType>> {
public:

/**
Expand All @@ -27,34 +30,34 @@ namespace boosting {
/**
* @param other A reference to an object of type `DenseNonDecomposableStatisticView` that should be copied
*/
DenseNonDecomposableStatisticView(DenseNonDecomposableStatisticView&& other);
DenseNonDecomposableStatisticView(DenseNonDecomposableStatisticView<StatisticType>&& other);

virtual ~DenseNonDecomposableStatisticView() override {}

/**
* An iterator that provides read-only access to the gradients.
*/
typedef AllocatedCContiguousView<float64>::value_const_iterator gradient_const_iterator;
typedef typename AllocatedCContiguousView<StatisticType>::value_const_iterator gradient_const_iterator;

/**
* An iterator that provides access to the gradients and allows to modify them.
*/
typedef AllocatedCContiguousView<float64>::value_iterator gradient_iterator;
typedef typename AllocatedCContiguousView<StatisticType>::value_iterator gradient_iterator;

/**
* An iterator that provides read-only access to the Hessians.
*/
typedef AllocatedCContiguousView<float64>::value_const_iterator hessian_const_iterator;
typedef typename AllocatedCContiguousView<StatisticType>::value_const_iterator hessian_const_iterator;

/**
* An iterator that provides access to the Hessians and allows to modify them.
*/
typedef AllocatedCContiguousView<float64>::value_iterator hessian_iterator;
typedef typename AllocatedCContiguousView<StatisticType>::value_iterator hessian_iterator;

/**
* An iterator that provides read-only access to the Hessians that correspond to the diagonal of the matrix.
*/
typedef DiagonalIterator<const float64> hessian_diagonal_const_iterator;
typedef DiagonalIterator<const StatisticType> hessian_diagonal_const_iterator;

/**
* Returns a `gradient_const_iterator` to the beginning of the gradients at a specific row.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ namespace boosting {
* predicted scores
* @param statisticView A reference to an object of type `DenseNonDecomposableStatisticView` to be updated
*/
virtual void updateNonDecomposableStatistics(uint32 exampleIndex,
const CContiguousView<const uint8>& labelMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView& statisticView) const = 0;
virtual void updateNonDecomposableStatistics(
uint32 exampleIndex, const CContiguousView<const uint8>& labelMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView<float64>& statisticView) const = 0;

/**
* Updates the statistics of the example at a specific index.
Expand All @@ -44,9 +44,9 @@ namespace boosting {
* predicted scores
* @param statisticView A reference to an object of type `DenseNonDecomposableStatisticView` to be updated
*/
virtual void updateNonDecomposableStatistics(uint32 exampleIndex, const BinaryCsrView& labelMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView& statisticView) const = 0;
virtual void updateNonDecomposableStatistics(
uint32 exampleIndex, const BinaryCsrView& labelMatrix, const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView<float64>& statisticView) const = 0;
};

/**
Expand All @@ -69,10 +69,10 @@ namespace boosting {
* @param statisticView A reference to an object of type `DenseNonDecomposableStatisticView` to be
* updated
*/
virtual void updateNonDecomposableStatistics(uint32 exampleIndex,
const CContiguousView<const float32>& regressionMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView& statisticView) const = 0;
virtual void updateNonDecomposableStatistics(
uint32 exampleIndex, const CContiguousView<const float32>& regressionMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView<float64>& statisticView) const = 0;

/**
* Updates the statistics of the example at a specific index.
Expand All @@ -85,10 +85,10 @@ namespace boosting {
* @param statisticView A reference to an object of type `DenseNonDecomposableStatisticView` to be
* updated
*/
virtual void updateNonDecomposableStatistics(uint32 exampleIndex,
const CsrView<const float32>& regressionMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView& statisticView) const = 0;
virtual void updateNonDecomposableStatistics(
uint32 exampleIndex, const CsrView<const float32>& regressionMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView<float64>& statisticView) const = 0;
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,49 +100,48 @@ namespace boosting {
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::add(const DenseNonDecomposableStatisticView& view,
uint32 row) {
void DenseNonDecomposableStatisticVector<StatisticType>::add(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row) {
util::addToView(this->gradients_begin(), view.gradients_cbegin(row), this->getNumGradients());
util::addToView(this->hessians_begin(), view.hessians_cbegin(row), this->getNumHessians());
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::add(const DenseNonDecomposableStatisticView& view,
uint32 row, StatisticType weight) {
void DenseNonDecomposableStatisticVector<StatisticType>::add(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, StatisticType weight) {
util::addToViewWeighted(this->gradients_begin(), view.gradients_cbegin(row), this->getNumGradients(), weight);
util::addToViewWeighted(this->hessians_begin(), view.hessians_cbegin(row), this->getNumHessians(), weight);
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::remove(const DenseNonDecomposableStatisticView& view,
uint32 row) {
void DenseNonDecomposableStatisticVector<StatisticType>::remove(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row) {
util::removeFromView(this->gradients_begin(), view.gradients_cbegin(row), this->getNumGradients());
util::removeFromView(this->hessians_begin(), view.hessians_cbegin(row), this->getNumHessians());
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::remove(const DenseNonDecomposableStatisticView& view,
uint32 row, StatisticType weight) {
void DenseNonDecomposableStatisticVector<StatisticType>::remove(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, StatisticType weight) {
util::removeFromViewWeighted(this->gradients_begin(), view.gradients_cbegin(row), this->getNumGradients(),
weight);
util::removeFromViewWeighted(this->hessians_begin(), view.hessians_cbegin(row), this->getNumHessians(), weight);
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(const DenseNonDecomposableStatisticView& view,
uint32 row,
const CompleteIndexVector& indices) {
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, const CompleteIndexVector& indices) {
util::addToView(this->gradients_begin(), view.gradients_cbegin(row), this->getNumGradients());
util::addToView(this->hessians_begin(), view.hessians_cbegin(row), this->getNumHessians());
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(const DenseNonDecomposableStatisticView& view,
uint32 row,
const PartialIndexVector& indices) {
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, const PartialIndexVector& indices) {
PartialIndexVector::const_iterator indexIterator = indices.cbegin();
util::addToView(this->gradients_begin(), view.gradients_cbegin(row), indexIterator, this->getNumGradients());
DenseNonDecomposableStatisticView::hessian_const_iterator hessiansBegin = view.hessians_cbegin(row);
typename DenseNonDecomposableStatisticView<StatisticType>::hessian_const_iterator hessiansBegin =
view.hessians_cbegin(row);

for (uint32 i = 0; i < this->getNumGradients(); i++) {
uint32 index = indexIterator[i];
Expand All @@ -152,21 +151,22 @@ namespace boosting {
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(const DenseNonDecomposableStatisticView& view,
uint32 row, const CompleteIndexVector& indices,
StatisticType weight) {
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, const CompleteIndexVector& indices,
StatisticType weight) {
util::addToViewWeighted(this->gradients_begin(), view.gradients_cbegin(row), this->getNumGradients(), weight);
util::addToViewWeighted(this->hessians_begin(), view.hessians_cbegin(row), this->getNumHessians(), weight);
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(const DenseNonDecomposableStatisticView& view,
uint32 row, const PartialIndexVector& indices,
StatisticType weight) {
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, const PartialIndexVector& indices,
StatisticType weight) {
PartialIndexVector::const_iterator indexIterator = indices.cbegin();
util::addToViewWeighted(this->gradients_begin(), view.gradients_cbegin(row), indexIterator,
this->getNumGradients(), weight);
DenseNonDecomposableStatisticView::hessian_const_iterator hessiansBegin = view.hessians_cbegin(row);
typename DenseNonDecomposableStatisticView<StatisticType>::hessian_const_iterator hessiansBegin =
view.hessians_cbegin(row);

for (uint32 i = 0; i < this->getNumGradients(); i++) {
uint32 index = indexIterator[i];
Expand Down
Loading