Skip to content

Commit

Permalink
collapse the bins if the caller requests a monotone direction on a no…
Browse files Browse the repository at this point in the history
…minal or multiclass target
  • Loading branch information
paulbkoch committed Dec 19, 2024
1 parent c7bc2cd commit f7e26ee
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
3 changes: 2 additions & 1 deletion shared/libebm/GenerateTermUpdate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,8 @@ EBM_API_BODY ErrorEbm EBM_CALLING_CONVENTION GenerateTermUpdate(void* rng,
// are going to remain having 0 splits.
pBoosterShell->GetInnerTermUpdate()->Reset();

if(IntEbm{0} == lastDimensionLeavesMax || (1 != cRealDimensions && MONOTONE_NONE != monotoneDirection)) {
if(IntEbm{0} == lastDimensionLeavesMax ||
((1 != cRealDimensions || bNominal || 1 != cScores) && MONOTONE_NONE != monotoneDirection)) {
// this is kind of hacky where if any one of a number of things occurs (like we have only 1 leaf)
// we sum everything into a single bin. The alternative would be to always sum into the tensor bins
// but then collapse them afterwards into a single bin, but that's more work.
Expand Down
4 changes: 4 additions & 0 deletions shared/libebm/PartitionOneDimensionalBoosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -779,9 +779,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
EBM_ASSERT(2 <= cBins); // filter these out at the start where we can handle this case easily
EBM_ASSERT(1 <= cSplitsMax); // filter these out at the start where we can handle this case easily
EBM_ASSERT(nullptr != pTotalGain);
EBM_ASSERT(!bNominal || MONOTONE_NONE == monotoneDirection);

BoosterCore* const pBoosterCore = pBoosterShell->GetBoosterCore();
const size_t cScores = GET_COUNT_SCORES(cCompilerScores, pBoosterCore->GetCountScores());

EBM_ASSERT(1 == cScores || MONOTONE_NONE == monotoneDirection);

const size_t cBytesPerBin = GetBinSize<FloatMain, UIntMain>(true, true, bHessian, cScores);
auto* const pRootTreeNode = pBoosterShell->GetTreeNodesTemp<bHessian, GetArrayScores(cCompilerScores)>();
pRootTreeNode->Init();
Expand Down

0 comments on commit f7e26ee

Please sign in to comment.