CategoricalTransformer with NaN values cannot be pickled #164

Closed
csala opened this issue May 21, 2021 · 0 comments · Fixed by #165

csala commented May 21, 2021

The CategoricalTransformer becomes unusable after being pickled and unpickled if the data it was fitted on contained NaN values.

The problem seems to come from this line:

mean, std = self.intervals[category][2:]

When the transformer is fitted, NaN values get their own entry in the intervals dictionary, keyed by np.nan, so that np.nan can later be transformed.
However, a dictionary lookup only finds a NaN key when it is given the very same object, because NaN never compares equal to anything, itself included. After pickling and unpickling, the NaN used as the key in the intervals dictionary is a different float('nan') instance, so looking up np.nan raises a KeyError and the transformer can no longer transform NaN values.
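The lookup behaviour described above can be reproduced with the standard library alone. np.nan is an ordinary Python float, so float('nan') stands in for it here; the dictionary and its tuple value are hypothetical placeholders, not the transformer's real intervals.

```python
import math
import pickle

# np.nan is a plain Python float, so float("nan") behaves the same way here.
nan = float("nan")

# Hypothetical intervals entry; only the last two items (mean, std) mirror
# the [2:] slice used in the issue.
intervals = {nan: (0.5, 1.0, 0.75, 0.1)}

print(nan in intervals)  # True: found because it is the same object
print(nan == nan)        # False: NaN is never equal to anything, itself included

# Round-trip through pickle: the restored key is a brand new float object.
restored = pickle.loads(pickle.dumps(intervals))
restored_key = next(iter(restored))

print(math.isnan(restored_key), restored_key is nan)  # True False
print(nan in restored)   # False: the same failure as the KeyError below
```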

Here is a snippet reproducing the issue:

>>> import pickle
>>> import numpy as np
>>> from rdt.transformers import CategoricalTransformer
>>> 
>>> nans = np.array([np.nan])
>>> 
>>> cat = CategoricalTransformer()
>>> cat.fit(nans)
>>> cat.transform(nans)
array([0.5])
>>> 
>>> with open('test.pkl', 'wb') as f:
...     pickle.dump(cat, f)
... 
>>> with open('test.pkl', 'rb') as f:
...     cat2 = pickle.load(f)
... 
>>> cat2.transform(nans)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/mnt/nvme0n1p2/xals/.virtualenvs/SDV/lib/python3.8/site-packages/rdt/transformers/categorical.py", line 125, in transform
    return data.fillna(np.nan).apply(self._get_value).to_numpy()
  File "/mnt/nvme0n1p2/xals/.virtualenvs/SDV/lib/python3.8/site-packages/pandas/core/series.py", line 4212, in apply
    mapped = lib.map_infer(values, f, convert=convert_dtype)
  File "pandas/_libs/lib.pyx", line 2403, in pandas._libs.lib.map_infer
  File "/mnt/nvme0n1p2/xals/.virtualenvs/SDV/lib/python3.8/site-packages/rdt/transformers/categorical.py", line 103, in _get_value
    mean, std = self.intervals[category][2:]
KeyError: nan

To fix this, we can add a __setstate__ method to the CategoricalTransformer that, on unpickling, replaces any null key with the np.nan object itself:

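import numpy as np
import pandas as pd


def __setstate__(self, state):
    # Unpickling creates a new NaN float object, so re-key any null entry
    # with np.nan itself; lookups with np.nan in transform() then succeed.
    intervals = state.get('intervals')
    if intervals:
        for key in list(intervals):
            if pd.isnull(key):
                intervals[np.nan] = intervals.pop(key)

    self.__dict__ = state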
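A self-contained sketch of the same fix, runnable without numpy or pandas. ToyTransformer is a hypothetical stand-in for CategoricalTransformer that only keeps an intervals dictionary; the module-level CANONICAL_NAN plays the role of np.nan, and _is_null is a narrower stdlib stand-in for pd.isnull (it only recognises None and float NaN).

```python
import math
import pickle

# Hypothetical canonical NaN object, playing the role np.nan plays in rdt.
CANONICAL_NAN = float("nan")


def _is_null(value):
    """Narrow stand-in for pd.isnull on scalar keys: None or float NaN."""
    return value is None or (isinstance(value, float) and math.isnan(value))


class ToyTransformer:
    """Hypothetical stand-in for CategoricalTransformer: only keeps `intervals`."""

    def __init__(self, intervals):
        self.intervals = intervals

    def __setstate__(self, state):
        # pickle calls this with the restored __dict__; re-key any null entry
        # so that lookups with CANONICAL_NAN find it again.
        intervals = state.get("intervals")
        if intervals:
            for key in list(intervals):
                if _is_null(key):
                    intervals[CANONICAL_NAN] = intervals.pop(key)
        self.__dict__ = state


original = ToyTransformer({
    "a": (0.0, 0.5, 0.25, 0.1),
    CANONICAL_NAN: (0.5, 1.0, 0.75, 0.1),
})
restored = pickle.loads(pickle.dumps(original))

print(CANONICAL_NAN in restored.intervals)    # True
print(restored.intervals[CANONICAL_NAN][2:])  # (0.75, 0.1)
```

Without the __setstate__ method, the same round trip leaves the entry keyed by a fresh NaN object and the final lookup raises a KeyError, mirroring the traceback above.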
npatki added the bug label May 21, 2021
csala added this to the 0.4.2 milestone May 28, 2021