From 2c8a046587b2197352c3afa09a9ec50838e79121 Mon Sep 17 00:00:00 2001 From: saroele Date: Mon, 5 Mar 2018 22:25:19 +0100 Subject: [PATCH 1/4] MWE and fix for #35 --- opengrid/library/regression.py | 38 ++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/opengrid/library/regression.py b/opengrid/library/regression.py index 8770a86..3f3a6b4 100644 --- a/opengrid/library/regression.py +++ b/opengrid/library/regression.py @@ -438,3 +438,41 @@ def plot(self, model=True, bar_chart=True, **kwargs): plt.show() return figures + + +class TestPickle(object): + """ + Examples + -------- + >>> from opengrid.library.regression import TestPickle + >>> import pickle + >>> tp = TestPickle('test') + >>> pickle.dump(tp, open('test.pkl', 'wb')) + """ + def __init__(self, x): + setattr(self, 'x', [Term([LookupFactor('endog')])]) + + def __getstate__(self): + d = self.__dict__ + d['temp'] = self.x[0].factors[0].name() + d.pop('x') + print("pickling, d={}".format(d)) + return d + + def __setstate__(self, state): + setattr(self, 'x', [Term([LookupFactor(state['temp'])])]) + +import attr + +@attr.s +class TestPickle2(object): + """ + Examples + -------- + >>> from opengrid.library.regression import TestPickle2 + >>> import pickle + >>> tp = TestPickle2() + >>> pickle.dump(tp, open('test.pkl', 'wb')) + """ + x = attr.ib([Term([LookupFactor('endog')])]) + From 4119b10939902e3dbb3f13316f18c28f5eb05dbc Mon Sep 17 00:00:00 2001 From: saroele Date: Thu, 8 Mar 2018 09:19:53 +0100 Subject: [PATCH 2/4] Removing all formulas from fits but resulting object still does not pickle #35 --- opengrid/library/regression.py | 72 ++++++++++++++++++++++++++++------ 1 file changed, 59 insertions(+), 13 deletions(-) diff --git a/opengrid/library/regression.py b/opengrid/library/regression.py index 3f3a6b4..d76064b 100644 --- a/opengrid/library/regression.py +++ b/opengrid/library/regression.py @@ -439,6 +439,65 @@ def plot(self, model=True, bar_chart=True, **kwargs): return figures + def _modeldesc_to_dict(self, md): + """Return a string representation of a patsy ModelDesc object""" + d = {'lhs_termlist': [md.lhs_termlist[0].factors[0].name()]} + rhs_termlist = [] + + # add other terms, if any + for term in md.rhs_termlist[:]: + if len(term.factors) == 0: + # intercept, represent by empty string + rhs_termlist.append('') + else: + rhs_termlist.append(term.factors[0].name()) + + d['rhs_termlist'] = rhs_termlist + return d + + def _modeldesc_from_dict(self, d): + """Return a string representation of a patsy ModelDesc object""" + lhs_termlist = [Term([LookupFactor(d['lhs_termlist'][0])])] + rhs_termlist = [] + for name in d['rhs_termlist']: + if name == '': + rhs_termlist.append(Term([])) + else: + rhs_termlist.append(Term([LookupFactor(name)])) + + md = ModelDesc(lhs_termlist, rhs_termlist) + return md + + def __getstate__(self): + """ + Remove attributes that cannot be pickled and store as dict. + + Each fit has a model.formula which is a patsy ModelDesc and this cannot be pickled. + We use our knowledge of this ModelDesc (as we build it up manually in the do_analysis() method) + and decompose it into a dictionary. This dictionary is stored in the list 'formulas', + one dict per fit. + + Of course we have to remove the fit.model.formula entirely, it is built-up again + from self.formulas in the __setstate__ method. + """ + + d = self.__dict__ + d['formulas'] = [] + for fit in self.list_of_fits: + d['formulas'].append(self._modeldesc_to_dict(fit.model.formula)) + delattr(fit.model, 'formula') + + print("Pickling... Removing the 'formula' from each fit.model.\n\ + You have to unpickle your object or run __setstate__ to restore them.".format(d)) + return d + + def __setstate__(self, state): + """Restore the attributes that cannot be pickled""" + for fit, formula in zip(self.list_of_fits, state['formulas']): + fit.model.formula = self._modeldesc_from_dict(formula) + delattr(self, 'formulas') + + class TestPickle(object): """ @@ -462,17 +521,4 @@ def __getstate__(self): def __setstate__(self, state): setattr(self, 'x', [Term([LookupFactor(state['temp'])])]) -import attr - -@attr.s -class TestPickle2(object): - """ - Examples - -------- - >>> from opengrid.library.regression import TestPickle2 - >>> import pickle - >>> tp = TestPickle2() - >>> pickle.dump(tp, open('test.pkl', 'wb')) - """ - x = attr.ib([Term([LookupFactor('endog')])]) From c5bd2f942bdac98870c5fceab1baf6acedb21294 Mon Sep 17 00:00:00 2001 From: saroele Date: Thu, 8 Mar 2018 09:40:58 +0100 Subject: [PATCH 3/4] Remove all fits from a MultiVarLinReg object to allow pickling. They are restored during unpickling --- opengrid/library/regression.py | 44 ++++++++++------------------------ 1 file changed, 12 insertions(+), 32 deletions(-) diff --git a/opengrid/library/regression.py b/opengrid/library/regression.py index d76064b..0774f91 100644 --- a/opengrid/library/regression.py +++ b/opengrid/library/regression.py @@ -477,48 +477,28 @@ def __getstate__(self): and decompose it into a dictionary. This dictionary is stored in the list 'formulas', one dict per fit. - Of course we have to remove the fit.model.formula entirely, it is built-up again + Finally we have to remove each fit entirely (not just the formula), it is built-up again from self.formulas in the __setstate__ method. """ - d = self.__dict__ d['formulas'] = [] for fit in self.list_of_fits: d['formulas'].append(self._modeldesc_to_dict(fit.model.formula)) - delattr(fit.model, 'formula') + #delattr(fit.model, 'formula') + d.pop('list_of_fits') + d.pop('fit') print("Pickling... Removing the 'formula' from each fit.model.\n\ - You have to unpickle your object or run __setstate__ to restore them.".format(d)) + You have to unpickle your object or run __setstate__(self.__dict__) to restore them.".format(d)) return d def __setstate__(self, state): """Restore the attributes that cannot be pickled""" - for fit, formula in zip(self.list_of_fits, state['formulas']): - fit.model.formula = self._modeldesc_from_dict(formula) - delattr(self, 'formulas') - - - -class TestPickle(object): - """ - Examples - -------- - >>> from opengrid.library.regression import TestPickle - >>> import pickle - >>> tp = TestPickle('test') - >>> pickle.dump(tp, open('test.pkl', 'wb')) - """ - def __init__(self, x): - setattr(self, 'x', [Term([LookupFactor('endog')])]) - - def __getstate__(self): - d = self.__dict__ - d['temp'] = self.x[0].factors[0].name() - d.pop('x') - print("pickling, d={}".format(d)) - return d - - def __setstate__(self, state): - setattr(self, 'x', [Term([LookupFactor(state['temp'])])]) - + for k,v in state.items(): + if k is not 'formulas': + setattr(self, k, v) + self.list_of_fits = [] + for formula in state['formulas']: + self.list_of_fits.append(fm.ols(self._modeldesc_from_dict(formula), data=self.df).fit()) + self.fit = self.list_of_fits[-1] From 7fd46a5dfaf9b6c06d248e3a7090e06f4453c6f4 Mon Sep 17 00:00:00 2001 From: saroele Date: Thu, 8 Mar 2018 09:55:17 +0100 Subject: [PATCH 4/4] Unittest for pickling round trip --- opengrid/tests/test_regression.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/opengrid/tests/test_regression.py b/opengrid/tests/test_regression.py index 0c95f0f..1cff809 100644 --- a/opengrid/tests/test_regression.py +++ b/opengrid/tests/test_regression.py @@ -11,6 +11,7 @@ import opengrid as og from opengrid import datasets import mock +import pickle plt_mocked = mock.Mock() ax_mock = mock.Mock() @@ -112,6 +113,22 @@ def test_prune(self): self.assertFalse("ba14" in mvlr.fit.model.exog_names) self.assertFalse("d5a7" in mvlr.fit.model.exog_names) + def test_pickle_round_trip(self): + "Pickle, unpickle and check results" + df = datasets.get('gas_2016_hour') + df_month = df.resample('MS').sum().loc['2016', :] + df_training = df_month.iloc[:-1, :] + df_pred = df_month.iloc[[-1], :] + mvlr = og.MultiVarLinReg(df_training, '313b', p_max=0.04) + mvlr.do_analysis() + df_pred_95_orig = mvlr._predict(mvlr.fit, df=df_pred) + + s = pickle.dumps(mvlr) + m = pickle.loads(s) + self.assertTrue(hasattr(m, 'list_of_fits')) + df_pred_95_roundtrip = m._predict(m.fit, df=df_pred) + self.assertAlmostEqual(df_pred_95_orig.loc['2016-12-01', 'predicted'], df_pred_95_roundtrip.loc['2016-12-01', 'predicted']) + if __name__ == '__main__': unittest.main()