Skip to content

Commit a6b0329

Browse files
authored
[ML] Fix issues upgrading state leading to possible abort of the autodetect process (#140)
Backport #136.
1 parent 4d0b8f0 commit a6b0329

File tree

6 files changed

+75
-22
lines changed

6 files changed

+75
-22
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
* Add control message to start background persistence {ml-pull}19[#19]
3434
* Fail start up if state is missing {ml-pull}4[#4]
3535
* Do not log incorrect model memory limit {ml-pull}3[#3]
36+
* The trend decomposition state wasn't being correctly upgraded potentially causing the autodetect process to abort {ml-pull}136[#136] (issue: {ml-issue}135[#135])
3637

3738
=== Regressions
3839

3940
=== Known Issues
40-

include/core/CStateMachine.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include <cstddef>
1717
#include <list>
18+
#include <map>
1819
#include <vector>
1920

2021
namespace ml {
@@ -67,6 +68,7 @@ class CORE_EXPORT CStateMachine {
6768
using TSizeVec = std::vector<std::size_t>;
6869
using TSizeVecVec = std::vector<TSizeVec>;
6970
using TStrVec = std::vector<std::string>;
71+
using TSizeSizeMap = std::map<std::size_t, std::size_t>;
7072

7173
public:
7274
//! Set the number of machines we expect the program to use.
@@ -85,7 +87,8 @@ class CORE_EXPORT CStateMachine {
8587
//! \name Persistence
8688
//@{
8789
//! Initialize by reading state from \p traverser.
88-
bool acceptRestoreTraverser(CStateRestoreTraverser& traverser);
90+
bool acceptRestoreTraverser(CStateRestoreTraverser& traverser,
91+
const TSizeSizeMap& mapping = TSizeSizeMap());
8992

9093
//! Persist state by passing information to the supplied inserter.
9194
void acceptPersistInserter(CStatePersistInserter& inserter) const;

lib/core/CStateMachine.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace core {
2525
namespace {
2626

2727
// CStateMachine
28-
const std::string MACHINE_TAG("a");
28+
//const std::string MACHINE_TAG("a"); No longer used
2929
const std::string STATE_TAG("b");
3030

3131
// CStateMachine::SMachine
@@ -88,17 +88,26 @@ CStateMachine CStateMachine::create(const TStrVec& alphabet,
8888
return result;
8989
}
9090

91-
bool CStateMachine::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
91+
bool CStateMachine::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser,
92+
const TSizeSizeMap& mapping) {
9293
do {
9394
const std::string& name = traverser.name();
94-
RESTORE_BUILT_IN(MACHINE_TAG, m_Machine)
9595
RESTORE_BUILT_IN(STATE_TAG, m_State)
9696
} while (traverser.next());
97+
if (mapping.size() > 0) {
98+
auto mapped = mapping.find(m_State);
99+
if (mapped != mapping.end()) {
100+
m_State = mapped->second;
101+
} else {
102+
LOG_ERROR(<< "Bad mapping '" << core::CContainerPrinter::print(mapping)
103+
<< "' state = " << m_State);
104+
return false;
105+
}
106+
}
97107
return true;
98108
}
99109

100110
void CStateMachine::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
101-
inserter.insertValue(MACHINE_TAG, m_Machine);
102111
inserter.insertValue(STATE_TAG, m_State);
103112
}
104113

@@ -201,14 +210,15 @@ void CStateMachine::CMachineDeque::capacity(std::size_t capacity) {
201210
m_Capacity = capacity;
202211
}
203212

204-
const CStateMachine::SMachine& CStateMachine::CMachineDeque::operator[](std::size_t pos) const {
213+
const CStateMachine::SMachine& CStateMachine::CMachineDeque::operator[](std::size_t pos_) const {
214+
std::size_t pos{pos_};
205215
for (const auto& machines : m_Machines) {
206216
if (pos < machines.size()) {
207217
return machines[pos];
208218
}
209219
pos -= machines.size();
210220
}
211-
LOG_ABORT(<< "Invalid index '" << pos << "'");
221+
LOG_ABORT(<< "Invalid index '" << pos_ << "'");
212222
}
213223

214224
std::size_t CStateMachine::CMachineDeque::size() const {

lib/core/unittest/CStateMachineTest.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,11 @@ void CStateMachineTest::testPersist() {
173173
core::CRapidXmlStateRestoreTraverser traverser(parser);
174174

175175
core::CStateMachine restored = core::CStateMachine::create(
176-
machine[1].s_Alphabet, machine[1].s_States, machine[1].s_TransitionFunction,
176+
machine[0].s_Alphabet, machine[0].s_States, machine[0].s_TransitionFunction,
177177
0); // initial state
178-
traverser.traverseSubLevel(
179-
boost::bind(&core::CStateMachine::acceptRestoreTraverser, &restored, _1));
178+
traverser.traverseSubLevel([&restored](core::CStateRestoreTraverser& traverser_) {
179+
return restored.acceptRestoreTraverser(traverser_);
180+
});
180181

181182
CPPUNIT_ASSERT_EQUAL(original.checksum(), restored.checksum());
182183
std::string newXml;

lib/maths/CTimeSeriesDecompositionDetail.cc

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343

4444
#include <algorithm>
4545
#include <cmath>
46+
#include <map>
47+
#include <numeric>
4648
#include <string>
4749
#include <vector>
4850

@@ -57,6 +59,7 @@ using TBoolVec = std::vector<bool>;
5759
using TDoubleVec = std::vector<double>;
5860
using TSizeVec = std::vector<std::size_t>;
5961
using TSizeVecVec = std::vector<TSizeVec>;
62+
using TSizeSizeMap = std::map<std::size_t, std::size_t>;
6063
using TStrVec = std::vector<std::string>;
6164
using TTimeVec = std::vector<core_t::TTime>;
6265
using TTimeTimePr = std::pair<core_t::TTime, core_t::TTime>;
@@ -307,7 +310,7 @@ const std::string LAST_UPDATE_OLD_TAG{"j"};
307310

308311
//////////////////////// Upgrade to Version 6.3 ////////////////////////
309312

310-
const double MODEL_WEIGHT_UPGRADING_TO_VERSION_6p3{48.0};
313+
const double MODEL_WEIGHT_UPGRADING_TO_VERSION_6_3{48.0};
311314

312315
bool upgradeTrendModelToVersion6p3(const core_t::TTime bucketLength,
313316
CTrendComponent& trend,
@@ -330,7 +333,7 @@ bool upgradeTrendModelToVersion6p3(const core_t::TTime bucketLength,
330333

331334
// Generate some samples from the old trend model.
332335

333-
double weight{MODEL_WEIGHT_UPGRADING_TO_VERSION_6p3 *
336+
double weight{MODEL_WEIGHT_UPGRADING_TO_VERSION_6_3 *
334337
static_cast<double>(bucketLength) / static_cast<double>(4 * WEEK)};
335338

336339
CPRNG::CXorOShiro128Plus rng;
@@ -343,6 +346,18 @@ bool upgradeTrendModelToVersion6p3(const core_t::TTime bucketLength,
343346
return true;
344347
}
345348

349+
// This implements the mapping from restored states to their best
350+
// equivalents; specifically:
351+
// SC_NEW_COMPONENTS |-> SC_NEW_COMPONENTS
352+
// SC_NORMAL |-> SC_NORMAL
353+
// SC_FORECASTING |-> SC_NORMAL
354+
// SC_DISABLED |-> SC_DISABLED
355+
// SC_ERROR |-> SC_ERROR
356+
// Note that we don't try and restore the periodicity test state
357+
// (see CTimeSeriesDecomposition::acceptRestoreTraverser) and the
358+
// calendar test state is unchanged.
359+
const TSizeSizeMap SC_STATES_UPGRADING_TO_VERSION_6_3{{0, 0}, {1, 1}, {2, 1}, {3, 2}, {4, 3}};
360+
346361
////////////////////////////////////////////////////////////////////////
347362

348363
// Constants
@@ -482,8 +497,9 @@ bool CTimeSeriesDecompositionDetail::CPeriodicityTest::acceptRestoreTraverser(
482497
do {
483498
const std::string& name{traverser.name()};
484499
RESTORE(PERIODICITY_TEST_MACHINE_6_3_TAG,
485-
traverser.traverseSubLevel(boost::bind(
486-
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)))
500+
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
501+
return m_Machine.acceptRestoreTraverser(traverser_);
502+
}))
487503
RESTORE_SETUP_TEARDOWN(
488504
SHORT_WINDOW_6_3_TAG, m_Windows[E_Short].reset(this->newWindow(E_Short)),
489505
m_Windows[E_Short] && traverser.traverseSubLevel(boost::bind(
@@ -759,8 +775,9 @@ bool CTimeSeriesDecompositionDetail::CCalendarTest::acceptRestoreTraverser(core:
759775
do {
760776
const std::string& name{traverser.name()};
761777
RESTORE(CALENDAR_TEST_MACHINE_6_3_TAG,
762-
traverser.traverseSubLevel(boost::bind(
763-
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)))
778+
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
779+
return m_Machine.acceptRestoreTraverser(traverser_);
780+
}))
764781
RESTORE_BUILT_IN(LAST_MONTH_6_3_TAG, m_LastMonth);
765782
RESTORE_SETUP_TEARDOWN(
766783
CALENDAR_TEST_6_3_TAG, m_Test.reset(new CCalendarCyclicTest(m_DecayRate)),
@@ -963,8 +980,9 @@ bool CTimeSeriesDecompositionDetail::CComponents::acceptRestoreTraverser(core::C
963980
while (traverser.next()) {
964981
const std::string& name{traverser.name()};
965982
RESTORE(COMPONENTS_MACHINE_6_3_TAG,
966-
traverser.traverseSubLevel(boost::bind(
967-
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)));
983+
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
984+
return m_Machine.acceptRestoreTraverser(traverser_);
985+
}))
968986
RESTORE_BUILT_IN(DECAY_RATE_6_3_TAG, m_DecayRate);
969987
RESTORE(TREND_6_3_TAG,
970988
traverser.traverseSubLevel(boost::bind(
@@ -995,8 +1013,10 @@ bool CTimeSeriesDecompositionDetail::CComponents::acceptRestoreTraverser(core::C
9951013
do {
9961014
const std::string& name{traverser.name()};
9971015
RESTORE(COMPONENTS_MACHINE_OLD_TAG,
998-
traverser.traverseSubLevel(boost::bind(
999-
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)));
1016+
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
1017+
return m_Machine.acceptRestoreTraverser(
1018+
traverser_, SC_STATES_UPGRADING_TO_VERSION_6_3);
1019+
}))
10001020
RESTORE_SETUP_TEARDOWN(TREND_OLD_TAG,
10011021
/**/,
10021022
traverser.traverseSubLevel(boost::bind(
@@ -1017,7 +1037,7 @@ bool CTimeSeriesDecompositionDetail::CComponents::acceptRestoreTraverser(core::C
10171037
/**/)
10181038
} while (traverser.next());
10191039

1020-
m_MeanVarianceScale.add(1.0, MODEL_WEIGHT_UPGRADING_TO_VERSION_6p3);
1040+
m_MeanVarianceScale.add(1.0, MODEL_WEIGHT_UPGRADING_TO_VERSION_6_3);
10211041
}
10221042
return true;
10231043
}
@@ -1679,6 +1699,7 @@ bool CTimeSeriesDecompositionDetail::CComponents::SSeasonal::acceptRestoreTraver
16791699
RESTORE(ERRORS_OLD_TAG,
16801700
core::CPersistUtils::restore(ERRORS_OLD_TAG, s_PredictionErrors, traverser))
16811701
} while (traverser.next());
1702+
s_PredictionErrors.resize(s_Components.size());
16821703
}
16831704
return true;
16841705
}
@@ -1907,6 +1928,7 @@ bool CTimeSeriesDecompositionDetail::CComponents::SCalendar::acceptRestoreTraver
19071928
RESTORE(ERRORS_OLD_TAG,
19081929
core::CPersistUtils::restore(ERRORS_OLD_TAG, s_PredictionErrors, traverser))
19091930
} while (traverser.next());
1931+
s_PredictionErrors.resize(s_Components.size());
19101932
}
19111933
return true;
19121934
}

lib/maths/unittest/CTimeSeriesDecompositionTest.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2108,6 +2108,9 @@ void CTimeSeriesDecompositionTest::testUpgrade() {
21082108
// Check we can validly upgrade existing state.
21092109

21102110
using TStrVec = std::vector<std::string>;
2111+
using TDouble3Vec = core::CSmallVector<double, 3>;
2112+
using TDouble3VecVec = std::vector<TDouble3Vec>;
2113+
21112114
auto load = [](const std::string& name, std::string& result) {
21122115
std::ifstream file;
21132116
file.open(name);
@@ -2186,6 +2189,13 @@ void CTimeSeriesDecompositionTest::testUpgrade() {
21862189
CPPUNIT_ASSERT_DOUBLES_EQUAL(expectedScale.second, scale.second,
21872190
0.005 * std::max(expectedScale.second, 0.4));
21882191
}
2192+
2193+
// Check some basic operations on the upgraded model.
2194+
TDouble3VecVec forecast;
2195+
decomposition.forecast(60480000, 60480000 + WEEK, HALF_HOUR, 90.0, 1.0, forecast);
2196+
for (core_t::TTime time = 60480000; time < 60480000 + WEEK; time += HALF_HOUR) {
2197+
decomposition.addPoint(time, 10.0);
2198+
}
21892199
}
21902200

21912201
LOG_DEBUG(<< "*** Trend and Seasonal Components ***");
@@ -2268,6 +2278,13 @@ void CTimeSeriesDecompositionTest::testUpgrade() {
22682278
LOG_DEBUG(<< "Mean scale error = " << maths::CBasicStatistics::mean(meanScaleError));
22692279
CPPUNIT_ASSERT(maths::CBasicStatistics::mean(meanValueError) < 0.06);
22702280
CPPUNIT_ASSERT(maths::CBasicStatistics::mean(meanScaleError) < 0.07);
2281+
2282+
// Check some basic operations on the upgraded model.
2283+
TDouble3VecVec forecast;
2284+
decomposition.forecast(10366200, 10366200 + WEEK, HALF_HOUR, 90.0, 1.0, forecast);
2285+
for (core_t::TTime time = 60480000; time < 60480000 + WEEK; time += HALF_HOUR) {
2286+
decomposition.addPoint(time, 10.0);
2287+
}
22712288
}
22722289
}
22732290

0 commit comments

Comments
 (0)