Skip to content

Commit a05a0ff

Browse files
authored
[ML] Handle unseen categories in encoding (#603)
Backport #602.
1 parent 21f42f7 commit a05a0ff

9 files changed

+252
-80
lines changed

include/core/CDataFrame.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,8 @@ class CORE_EXPORT CDataFrame final {
433433
std::size_t numberRows,
434434
std::size_t numberColumns);
435435

436-
// TODO We may want an architecture agnostic check pointing mechanism for long
437-
// running tasks.
436+
//! Get the value to use for a missing element in a data frame.
437+
static double valueOfMissing();
438438

439439
private:
440440
using TSizeSizePr = std::pair<std::size_t, std::size_t>;

include/maths/CDataFrameCategoryEncoder.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,9 @@ class MATHS_EXPORT CDataFrameCategoryEncoder final {
219219
TSizeVecVec m_OneHotEncodedCategories;
220220
TSizeUSetVec m_RareCategories;
221221
TDoubleVecVec m_CategoryFrequencies;
222-
TDoubleVecVec m_TargetMeanValues;
222+
TDoubleVec m_MeanCategoryFrequencies;
223+
TDoubleVecVec m_CategoryTargetMeanValues;
224+
TDoubleVec m_MeanCategoryTargetMeanValues;
223225
TDoubleVec m_FeatureVectorMics;
224226
TSizeVec m_FeatureVectorColumnMap;
225227
TSizeVec m_FeatureVectorEncodingMap;

lib/api/CDataFrameAnalyzer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,10 @@ void CDataFrameAnalyzer::addRowToDataFrame(const TStrVec& fieldValues) {
319319
double value;
320320
if (fieldValue.empty()) {
321321
++m_MissingValueCount;
322-
return core::CFloatStorage{std::numeric_limits<float>::quiet_NaN()};
322+
return core::CFloatStorage{core::CDataFrame::valueOfMissing()};
323323
} else if (core::CStringUtils::stringToTypeSilent(fieldValue, value) == false) {
324324
++m_BadValueCount;
325-
return core::CFloatStorage{std::numeric_limits<float>::quiet_NaN()};
325+
return core::CFloatStorage{core::CDataFrame::valueOfMissing()};
326326
}
327327

328328
// Tuncation is very unlikely since the values will typically be

lib/core/CDataFrame.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include <algorithm>
1818
#include <future>
19+
#include <limits>
1920
#include <memory>
2021

2122
namespace ml {
@@ -261,6 +262,10 @@ std::size_t CDataFrame::estimateMemoryUsage(bool inMainMemory,
261262
return inMainMemory ? numberRows * numberColumns * sizeof(float) : 0;
262263
}
263264

265+
double CDataFrame::valueOfMissing() {
266+
return std::numeric_limits<double>::quiet_NaN();
267+
}
268+
264269
CDataFrame::TRowFuncVecBoolPr
265270
CDataFrame::parallelApplyToAllRows(std::size_t numberThreads,
266271
std::size_t beginRows,

lib/maths/CDataFrameCategoryEncoder.cc

Lines changed: 84 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ const std::string COLUMN_USES_FREQUENCY_ENCODING_TAG{"uses_frequency_encoding"};
198198
const std::string ONE_HOT_ENCODED_CATEGORIES_TAG{"one_hot_encoded_categories"};
199199
const std::string RARE_CATEGORIES_TAG{"rare_categories"};
200200
const std::string CATEGORY_FREQUENCIES_TAG{"category_frequencies"};
201-
const std::string TARGET_MEAN_VALUES_TAG{"target_mean_values"};
201+
const std::string MEAN_CATEGORY_FREQUENCIES_TAG{"mean_category_frequencies"};
202+
const std::string CATEGORY_TARGET_MEAN_VALUES_TAG{"category_target_mean_values"};
203+
const std::string MEAN_CATEGORY_TARGET_MEAN_VALUES_TAG{"mean_category_target_mean_values"};
202204
const std::string FEATURE_VECTOR_MICS_TAG{"feature_vector_mics"};
203205
const std::string FEATURE_VECTOR_COLUMN_MAP_TAG{"feature_vector_column_map"};
204206
const std::string FEATURE_VECTOR_ENCODING_MAP_TAG{"feature_vector_encoding_map"};
@@ -376,9 +378,9 @@ bool CDataFrameCategoryEncoder::usesFrequencyEncoding(std::size_t feature) const
376378
}
377379

378380
double CDataFrameCategoryEncoder::frequency(std::size_t feature, std::size_t category) const {
379-
return this->usesOneHotEncoding(feature, category)
380-
? 0.0
381-
: m_CategoryFrequencies[feature][category];
381+
const auto& frequencies = m_CategoryFrequencies[feature];
382+
return category < frequencies.size() ? frequencies[category]
383+
: m_MeanCategoryFrequencies[feature];
382384
}
383385

384386
bool CDataFrameCategoryEncoder::isRareCategory(std::size_t feature, std::size_t category) const {
@@ -387,11 +389,9 @@ bool CDataFrameCategoryEncoder::isRareCategory(std::size_t feature, std::size_t
387389

388390
double CDataFrameCategoryEncoder::targetMeanValue(std::size_t feature,
389391
std::size_t category) const {
390-
// TODO combine rare categories and use one mapping for collections.
391-
return this->usesOneHotEncoding(feature, category) ||
392-
this->isRareCategory(feature, category)
393-
? 0.0
394-
: m_TargetMeanValues[feature][category];
392+
const auto& targetMeanValues = m_CategoryTargetMeanValues[feature];
393+
return category < targetMeanValues.size() ? targetMeanValues[category]
394+
: m_MeanCategoryTargetMeanValues[feature];
395395
}
396396

397397
std::uint64_t CDataFrameCategoryEncoder::checksum(std::uint64_t seed) const {
@@ -403,7 +403,9 @@ std::uint64_t CDataFrameCategoryEncoder::checksum(std::uint64_t seed) const {
403403
seed = CChecksum::calculate(seed, m_OneHotEncodedCategories);
404404
seed = CChecksum::calculate(seed, m_RareCategories);
405405
seed = CChecksum::calculate(seed, m_CategoryFrequencies);
406-
seed = CChecksum::calculate(seed, m_TargetMeanValues);
406+
seed = CChecksum::calculate(seed, m_MeanCategoryFrequencies);
407+
seed = CChecksum::calculate(seed, m_CategoryTargetMeanValues);
408+
seed = CChecksum::calculate(seed, m_MeanCategoryTargetMeanValues);
407409
seed = CChecksum::calculate(seed, m_FeatureVectorMics);
408410
seed = CChecksum::calculate(seed, m_FeatureVectorColumnMap);
409411
return CChecksum::calculate(seed, m_FeatureVectorEncodingMap);
@@ -422,7 +424,12 @@ void CDataFrameCategoryEncoder::acceptPersistInserter(core::CStatePersistInserte
422424
m_OneHotEncodedCategories, inserter);
423425
core::CPersistUtils::persist(RARE_CATEGORIES_TAG, m_RareCategories, inserter);
424426
core::CPersistUtils::persist(CATEGORY_FREQUENCIES_TAG, m_CategoryFrequencies, inserter);
425-
core::CPersistUtils::persist(TARGET_MEAN_VALUES_TAG, m_TargetMeanValues, inserter);
427+
core::CPersistUtils::persist(MEAN_CATEGORY_FREQUENCIES_TAG,
428+
m_MeanCategoryFrequencies, inserter);
429+
core::CPersistUtils::persist(CATEGORY_TARGET_MEAN_VALUES_TAG,
430+
m_CategoryTargetMeanValues, inserter);
431+
core::CPersistUtils::persist(MEAN_CATEGORY_TARGET_MEAN_VALUES_TAG,
432+
m_MeanCategoryTargetMeanValues, inserter);
426433
core::CPersistUtils::persist(FEATURE_VECTOR_MICS_TAG, m_FeatureVectorMics, inserter);
427434
core::CPersistUtils::persist(FEATURE_VECTOR_COLUMN_MAP_TAG,
428435
m_FeatureVectorColumnMap, inserter);
@@ -450,8 +457,15 @@ bool CDataFrameCategoryEncoder::acceptRestoreTraverser(core::CStateRestoreTraver
450457
RESTORE(CATEGORY_FREQUENCIES_TAG,
451458
core::CPersistUtils::restore(CATEGORY_FREQUENCIES_TAG,
452459
m_CategoryFrequencies, traverser))
453-
RESTORE(TARGET_MEAN_VALUES_TAG,
454-
core::CPersistUtils::restore(TARGET_MEAN_VALUES_TAG, m_TargetMeanValues, traverser))
460+
RESTORE(MEAN_CATEGORY_FREQUENCIES_TAG,
461+
core::CPersistUtils::restore(MEAN_CATEGORY_FREQUENCIES_TAG,
462+
m_MeanCategoryFrequencies, traverser))
463+
RESTORE(CATEGORY_TARGET_MEAN_VALUES_TAG,
464+
core::CPersistUtils::restore(CATEGORY_TARGET_MEAN_VALUES_TAG,
465+
m_CategoryTargetMeanValues, traverser))
466+
RESTORE(MEAN_CATEGORY_TARGET_MEAN_VALUES_TAG,
467+
core::CPersistUtils::restore(MEAN_CATEGORY_TARGET_MEAN_VALUES_TAG,
468+
m_MeanCategoryTargetMeanValues, traverser))
455469
RESTORE(FEATURE_VECTOR_MICS_TAG,
456470
core::CPersistUtils::restore(FEATURE_VECTOR_MICS_TAG,
457471
m_FeatureVectorMics, traverser))
@@ -483,7 +497,7 @@ CDataFrameCategoryEncoder::mics(std::size_t numberThreads,
483497
encoderFactories[E_TargetMean] = std::make_pair(
484498
[this](std::size_t column, std::size_t sampleColumn, std::size_t) {
485499
return std::make_unique<CDataFrameUtils::CTargetMeanCategoricalColumnValue>(
486-
sampleColumn, m_RareCategories[column], m_TargetMeanValues[column]);
500+
sampleColumn, m_RareCategories[column], m_CategoryTargetMeanValues[column]);
487501
},
488502
0.0);
489503
encoderFactories[E_Frequency] = std::make_pair(
@@ -531,8 +545,13 @@ void CDataFrameCategoryEncoder::setupFrequencyEncoding(std::size_t numberThreads
531545
LOG_TRACE(<< "category frequencies = "
532546
<< core::CContainerPrinter::print(m_CategoryFrequencies));
533547

534-
m_RareCategories.resize(frame.numberColumns());
548+
m_MeanCategoryFrequencies.resize(m_CategoryFrequencies.size());
549+
m_RareCategories.resize(m_CategoryFrequencies.size());
535550
for (std::size_t i = 0; i < m_CategoryFrequencies.size(); ++i) {
551+
m_MeanCategoryFrequencies[i] =
552+
m_CategoryFrequencies[i].empty()
553+
? 1.0
554+
: 1.0 / static_cast<double>(m_CategoryFrequencies[i].size());
536555
for (std::size_t j = 0; j < m_CategoryFrequencies[i].size(); ++j) {
537556
std::size_t count{static_cast<std::size_t>(
538557
m_CategoryFrequencies[i][j] * static_cast<double>(frame.numberRows()) + 0.5)};
@@ -541,6 +560,8 @@ void CDataFrameCategoryEncoder::setupFrequencyEncoding(std::size_t numberThreads
541560
}
542561
}
543562
}
563+
LOG_TRACE(<< "mean category frequencies = "
564+
<< core::CContainerPrinter::print(m_MeanCategoryFrequencies));
544565
LOG_TRACE(<< "rare categories = " << core::CContainerPrinter::print(m_RareCategories));
545566
}
546567

@@ -550,11 +571,21 @@ void CDataFrameCategoryEncoder::setupTargetMeanValueEncoding(std::size_t numberT
550571
const TSizeVec& categoricalColumnMask,
551572
std::size_t targetColumn) {
552573

553-
m_TargetMeanValues = CDataFrameUtils::meanValueOfTargetForCategories(
574+
m_CategoryTargetMeanValues = CDataFrameUtils::meanValueOfTargetForCategories(
554575
CDataFrameUtils::CMetricColumnValue{targetColumn}, numberThreads, frame,
555576
rowMask, categoricalColumnMask);
556-
LOG_TRACE(<< "target mean values = "
557-
<< core::CContainerPrinter::print(m_TargetMeanValues));
577+
LOG_TRACE(<< "category target mean values = "
578+
<< core::CContainerPrinter::print(m_CategoryTargetMeanValues));
579+
580+
m_MeanCategoryTargetMeanValues.resize(m_CategoryTargetMeanValues.size());
581+
for (std::size_t i = 0; i < m_CategoryTargetMeanValues.size(); ++i) {
582+
m_MeanCategoryTargetMeanValues[i] =
583+
m_CategoryTargetMeanValues[i].empty()
584+
? 0.0
585+
: CBasicStatistics::mean(m_CategoryTargetMeanValues[i]);
586+
}
587+
LOG_TRACE(<< "mean category target mean values = "
588+
<< core::CContainerPrinter::print(m_MeanCategoryTargetMeanValues));
558589
}
559590

560591
CDataFrameCategoryEncoder::TSizeSizePrDoubleMap
@@ -654,9 +685,9 @@ CDataFrameCategoryEncoder::selectFeatures(std::size_t numberThreads,
654685
metricColumnMask.end(), feature));
655686
} // else if (selected.isTargetMean()) { nothing to do }
656687

657-
auto columnValue = selected.columnValue(m_RareCategories[feature],
658-
m_CategoryFrequencies[feature],
659-
m_TargetMeanValues[feature]);
688+
auto columnValue = selected.columnValue(
689+
m_RareCategories[feature], m_CategoryFrequencies[feature],
690+
m_CategoryTargetMeanValues[feature]);
660691
mics = this->mics(numberThreads, frame, *columnValue, rowMask,
661692
metricColumnMask, categoricalColumnMask);
662693
search.update(mics);
@@ -679,6 +710,38 @@ CDataFrameCategoryEncoder::selectFeatures(std::size_t numberThreads,
679710
void CDataFrameCategoryEncoder::finishEncoding(std::size_t targetColumn,
680711
TSizeSizePrDoubleMap selectedFeatureMics) {
681712

713+
using TMeanAccumulator = CBasicStatistics::SSampleMean<double>::TAccumulator;
714+
715+
// Update the frequency and target mean encoding for one-hot and rare categories.
716+
717+
for (std::size_t i = 0; i < m_OneHotEncodedCategories.size(); ++i) {
718+
TMeanAccumulator meanCategoryFrequency;
719+
TMeanAccumulator meanCategoryTargetMeanValue;
720+
for (auto category : m_OneHotEncodedCategories[i]) {
721+
double frequency{m_CategoryFrequencies[i][category]};
722+
double mean{m_CategoryTargetMeanValues[i][category]};
723+
meanCategoryFrequency.add(frequency, frequency);
724+
meanCategoryTargetMeanValue.add(mean, frequency);
725+
}
726+
for (auto category : m_OneHotEncodedCategories[i]) {
727+
m_CategoryFrequencies[i][category] = CBasicStatistics::mean(meanCategoryFrequency);
728+
m_CategoryTargetMeanValues[i][category] =
729+
CBasicStatistics::mean(meanCategoryTargetMeanValue);
730+
}
731+
}
732+
for (std::size_t i = 0; i < m_RareCategories.size(); ++i) {
733+
TMeanAccumulator meanCategoryTargetMeanValue;
734+
for (auto category : m_RareCategories[i]) {
735+
double frequency{m_CategoryFrequencies[i][category]};
736+
double mean{m_CategoryTargetMeanValues[i][category]};
737+
meanCategoryTargetMeanValue.add(mean, frequency);
738+
}
739+
for (auto category : m_RareCategories[i]) {
740+
m_CategoryTargetMeanValues[i][category] =
741+
CBasicStatistics::mean(meanCategoryTargetMeanValue);
742+
}
743+
}
744+
682745
// Fill in a mapping from encoded column indices to raw column indices.
683746

684747
selectedFeatureMics[{targetColumn, CATEGORY_FOR_DEPENDENT_VARIABLE}] = 0.0;

lib/maths/unittest/CBoostedTreeTest.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ auto predictAndComputeEvaluationMetrics(const F& generateFunction,
118118
for (auto row = beginRows; row != endRows; ++row) {
119119
double targetValue{row->index() < trainRows
120120
? target(*row) + noise[row->index()]
121-
: std::numeric_limits<double>::quiet_NaN()};
121+
: core::CDataFrame::valueOfMissing()};
122122
row->writeColumn(cols - 1, targetValue);
123123
}
124124
});
@@ -582,7 +582,7 @@ void CBoostedTreeTest::testCategoricalRegressors() {
582582
for (auto row = beginRows; row != endRows; ++row) {
583583
double targetValue{row->index() < trainRows
584584
? target(*row)
585-
: std::numeric_limits<double>::quiet_NaN()};
585+
: core::CDataFrame::valueOfMissing()};
586586
row->writeColumn(cols - 1, targetValue);
587587
}
588588
});
@@ -602,8 +602,8 @@ void CBoostedTreeTest::testCategoricalRegressors() {
602602

603603
LOG_DEBUG(<< "bias = " << modelBias);
604604
LOG_DEBUG(<< " R^2 = " << modelRSquared);
605-
CPPUNIT_ASSERT_DOUBLES_EQUAL(0.0, modelBias, 0.06);
606-
CPPUNIT_ASSERT(modelRSquared > 0.97);
605+
CPPUNIT_ASSERT_DOUBLES_EQUAL(0.0, modelBias, 0.1);
606+
CPPUNIT_ASSERT(modelRSquared > 0.9);
607607
}
608608

609609
void CBoostedTreeTest::testProgressMonitoring() {

0 commit comments

Comments
 (0)