Skip to content

Commit

Permalink
Merge pull request #36 from opengridcc/issue35
Browse files Browse the repository at this point in the history
Pickling of multivarlinreg objects
  • Loading branch information
JrtPec authored Mar 19, 2018
2 parents 7a91d29 + 7fd46a5 commit 3a6af26
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
64 changes: 64 additions & 0 deletions opengrid/library/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,67 @@ def plot(self, model=True, bar_chart=True, **kwargs):
plt.show()

return figures

def _modeldesc_to_dict(self, md):
"""Return a string representation of a patsy ModelDesc object"""
d = {'lhs_termlist': [md.lhs_termlist[0].factors[0].name()]}
rhs_termlist = []

# add other terms, if any
for term in md.rhs_termlist[:]:
if len(term.factors) == 0:
# intercept, represent by empty string
rhs_termlist.append('')
else:
rhs_termlist.append(term.factors[0].name())

d['rhs_termlist'] = rhs_termlist
return d

def _modeldesc_from_dict(self, d):
"""Return a string representation of a patsy ModelDesc object"""
lhs_termlist = [Term([LookupFactor(d['lhs_termlist'][0])])]
rhs_termlist = []
for name in d['rhs_termlist']:
if name == '':
rhs_termlist.append(Term([]))
else:
rhs_termlist.append(Term([LookupFactor(name)]))

md = ModelDesc(lhs_termlist, rhs_termlist)
return md

def __getstate__(self):
"""
Remove attributes that cannot be pickled and store as dict.
Each fit has a model.formula which is a patsy ModelDesc and this cannot be pickled.
We use our knowledge of this ModelDesc (as we build it up manually in the do_analysis() method)
and decompose it into a dictionary. This dictionary is stored in the list 'formulas',
one dict per fit.
Finally we have to remove each fit entirely (not just the formula), it is built-up again
from self.formulas in the __setstate__ method.
"""
d = self.__dict__
d['formulas'] = []
for fit in self.list_of_fits:
d['formulas'].append(self._modeldesc_to_dict(fit.model.formula))
#delattr(fit.model, 'formula')
d.pop('list_of_fits')
d.pop('fit')

print("Pickling... Removing the 'formula' from each fit.model.\n\
You have to unpickle your object or run __setstate__(self.__dict__) to restore them.".format(d))
return d

def __setstate__(self, state):
"""Restore the attributes that cannot be pickled"""
for k,v in state.items():
if k is not 'formulas':
setattr(self, k, v)
self.list_of_fits = []
for formula in state['formulas']:
self.list_of_fits.append(fm.ols(self._modeldesc_from_dict(formula), data=self.df).fit())
self.fit = self.list_of_fits[-1]

17 changes: 17 additions & 0 deletions opengrid/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import opengrid as og
from opengrid import datasets
import mock
import pickle

plt_mocked = mock.Mock()
ax_mock = mock.Mock()
Expand Down Expand Up @@ -112,6 +113,22 @@ def test_prune(self):
self.assertFalse("ba14" in mvlr.fit.model.exog_names)
self.assertFalse("d5a7" in mvlr.fit.model.exog_names)

def test_pickle_round_trip(self):
"Pickle, unpickle and check results"
df = datasets.get('gas_2016_hour')
df_month = df.resample('MS').sum().loc['2016', :]
df_training = df_month.iloc[:-1, :]
df_pred = df_month.iloc[[-1], :]
mvlr = og.MultiVarLinReg(df_training, '313b', p_max=0.04)
mvlr.do_analysis()
df_pred_95_orig = mvlr._predict(mvlr.fit, df=df_pred)

s = pickle.dumps(mvlr)
m = pickle.loads(s)
self.assertTrue(hasattr(m, 'list_of_fits'))
df_pred_95_roundtrip = m._predict(m.fit, df=df_pred)
self.assertAlmostEqual(df_pred_95_orig.loc['2016-12-01', 'predicted'], df_pred_95_roundtrip.loc['2016-12-01', 'predicted'])


if __name__ == '__main__':
unittest.main()

0 comments on commit 3a6af26

Please sign in to comment.