Skip to content

Commit b73df65

Browse files
authored
[ML] Improve robustness to very low variance data (#232)
1 parent ece8ca0 commit b73df65

File tree

3 files changed

+18
-16
lines changed

3 files changed

+18
-16
lines changed

include/maths/CPeriodicityHypothesisTests.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ class MATHS_EXPORT CPeriodicityHypothesisTests {
208208

209209
//! \brief A collection of statistics used during testing.
210210
struct STestStats {
211-
STestStats();
211+
explicit STestStats(double meanMagnitude);
212212
//! Set the various test thresholds.
213213
void setThresholds(double vt, double at, double Rt);
214214
//! Check if the null hypothesis is good enough to not need an
@@ -233,6 +233,8 @@ class MATHS_EXPORT CPeriodicityHypothesisTests {
233233
double s_NonEmptyBuckets;
234234
//! The average number of measurements per bucket value.
235235
double s_MeasurementsPerBucket;
236+
//! The mean magnitude of the bucket values.
237+
double s_MeanMagnitude;
236238
//! The null hypothesis periodic components.
237239
CPeriodicityHypothesisTestsResult s_H0;
238240
//! The variance estimate of H0.

lib/maths/CPeriodicityHypothesisTests.cc

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,8 +1154,16 @@ CPeriodicityHypothesisTests::best(const TNestedHypothesesVec& hypotheses) const
11541154
THypothesisSummaryVec summaries;
11551155
summaries.reserve(hypotheses.size());
11561156

1157+
double meanMagnitude{CBasicStatistics::mean(std::accumulate(
1158+
m_BucketValues.begin(), m_BucketValues.end(), TMeanAccumulator{},
1159+
[](TMeanAccumulator partial, const TFloatMeanAccumulator& value) {
1160+
partial.add(std::fabs(CBasicStatistics::mean(value)),
1161+
CBasicStatistics::count(value));
1162+
return partial;
1163+
}))};
1164+
11571165
for (const auto& hypothesis : hypotheses) {
1158-
STestStats stats;
1166+
STestStats stats{meanMagnitude};
11591167
stats.s_TrendSegments = static_cast<double>(hypothesis.trendSegments());
11601168
CPeriodicityHypothesisTestsResult resultForHypothesis{hypothesis.test(stats)};
11611169
if (stats.s_NonEmptyBuckets > stats.s_DF0) {
@@ -2195,13 +2203,14 @@ bool CPeriodicityHypothesisTests::testAmplitude(const TTimeTimePr2Vec& window,
21952203
const double CPeriodicityHypothesisTests::ACCURATE_TEST_POPULATED_FRACTION{0.9};
21962204
const double CPeriodicityHypothesisTests::MINIMUM_COEFFICIENT_OF_VARIATION{1e-4};
21972205

2198-
CPeriodicityHypothesisTests::STestStats::STestStats()
2206+
CPeriodicityHypothesisTests::STestStats::STestStats(double meanMagnitude)
21992207
: s_TrendSegments(1.0), s_HasPeriod(false), s_HasPartition(false),
22002208
s_VarianceThreshold(COMPONENT_SIGNIFICANT_VARIANCE_REDUCTION[E_HighThreshold]),
22012209
s_AmplitudeThreshold(SEASONAL_SIGNIFICANT_AMPLITUDE[E_HighThreshold]),
22022210
s_AutocorrelationThreshold(SEASONAL_SIGNIFICANT_AUTOCORRELATION[E_HighThreshold]),
22032211
s_Range(0.0), s_NonEmptyBuckets(0.0), s_MeasurementsPerBucket(0.0),
2204-
s_V0(0.0), s_R0(0.0), s_DF0(0.0), s_StartOfPartition(0) {
2212+
s_MeanMagnitude(meanMagnitude), s_V0(0.0), s_R0(0.0), s_DF0(0.0),
2213+
s_StartOfPartition(0) {
22052214
}
22062215

22072216
void CPeriodicityHypothesisTests::STestStats::setThresholds(double vt, double at, double Rt) {
@@ -2211,16 +2220,7 @@ void CPeriodicityHypothesisTests::STestStats::setThresholds(double vt, double at
22112220
}
22122221

22132222
bool CPeriodicityHypothesisTests::STestStats::nullHypothesisGoodEnough() const {
2214-
TMeanAccumulator mean;
2215-
for (const auto& t : s_T0) {
2216-
mean += std::accumulate(t.begin(), t.end(), TMeanAccumulator(),
2217-
[](TMeanAccumulator m, double x) {
2218-
m.add(std::fabs(x));
2219-
return m;
2220-
});
2221-
}
2222-
return std::sqrt(s_V0) <=
2223-
MINIMUM_COEFFICIENT_OF_VARIATION * CBasicStatistics::mean(mean);
2223+
return std::sqrt(s_V0) <= MINIMUM_COEFFICIENT_OF_VARIATION * s_MeanMagnitude;
22242224
}
22252225

22262226
CPeriodicityHypothesisTests::CNestedHypotheses::CNestedHypotheses(TTestFunc test)

lib/maths/CTimeSeriesDecompositionDetail.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,14 +1215,14 @@ void CTimeSeriesDecompositionDetail::CComponents::handle(const SAddValue& messag
12151215
for (std::size_t i = 1u; i <= m; ++i) {
12161216
CSeasonalComponent* component{seasonalComponents[i - 1]};
12171217
CComponentErrors* error_{seasonalErrors[i - 1]};
1218-
double varianceIncrease{variances[i] / variance / expectedVarianceIncrease};
1218+
double varianceIncrease{variance == 0.0 ? 1.0 : variances[i] / variance / expectedVarianceIncrease};
12191219
component->add(time, values[i], weight);
12201220
error_->add(referenceError, error, predictions[i - 1], varianceIncrease, weight);
12211221
}
12221222
for (std::size_t i = m + 1; i <= m + n; ++i) {
12231223
CCalendarComponent* component{calendarComponents[i - m - 1]};
12241224
CComponentErrors* error_{calendarErrors[i - m - 1]};
1225-
double varianceIncrease{variances[i] / variance / expectedVarianceIncrease};
1225+
double varianceIncrease{variance == 0.0 ? 1.0 : variances[i] / variance / expectedVarianceIncrease};
12261226
component->add(time, values[i], weight);
12271227
error_->add(referenceError, error, predictions[i - 1], varianceIncrease, weight);
12281228
}

0 commit comments

Comments
 (0)