Description
Environment Details
- RDT version: 1.3.0
- Python version: 3.8
- Operating System: Linux (Colab Notebook)
Error Description
When I use a OneHotEncoder on data that includes nan values, the transformer works correctly. If I save/reload the object, however, there is an unnecessary warning claiming that there are unseen categories (nan).
/usr/local/lib/python3.8/dist-packages/rdt/transformers/categorical.py:381: UserWarning: The data contains 1 new categories that were not seen in the original data (examples: {nan}). Creating a vector of all 0s. If you want to model new categories, please fit the transformer again with the new data.
  warnings.warn(
The warning does not occur unless you save/reload the object. It does not seem to be necessary, as the transformer correctly identifies and transforms nan values.
Steps to reproduce
import pickle

import numpy as np
import pandas as pd
from rdt import HyperTransformer
from rdt.transformers.categorical import LabelEncoder, OneHotEncoder

# A categorical column that contains missing values
data = pd.DataFrame(data={
    'column_name': [1.0, 2.0, np.nan, 2.0, 3.0, np.nan, 3.0]
})

ht = HyperTransformer()
ht.set_config({
    'sdtypes': {'column_name': 'categorical'},
    'transformers': {'column_name': OneHotEncoder()}
})
ht.fit(data)

# Save and reload the fitted HyperTransformer
with open('ht.pkl', 'wb') as f:
    pickle.dump(ht, f)
with open('ht.pkl', 'rb') as f:
    ht_loaded = pickle.load(f)

# The warning is raised here, only for the reloaded object
transformed = ht_loaded.transform(data)
Additional Context
After saving/reloading the object, the check in the following line seems to be producing {nan}:
unseen_categories = unique_data - set(self.dummies)
It's unclear why, as the nan values appear to be the same.
>>> ohe = ht_loaded.get_config()['transformers']['column_name']
>>> set(ohe.dummies)
{ nan, 1.0, 2.0, 3.0 }
>>> d = list(data['column_name'])
>>> unique_data = { np.nan if pd.isna(x) else x for x in pd.unique(d) }
>>> unique_data
{ nan, 1.0, 2.0, 3.0 }
>>> unique_data - set(ohe.dummies)
{ nan }
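One likely explanation is that set operations match elements by identity first and by equality second. nan never compares equal to itself, so it only matches when both sides hold the very same float object, and pickling creates a new one. The sketch below reproduces this with the standard library alone. It assumes that dummies is a plain list holding the np.nan object itself; math.nan stands in for np.nan here, since both are ordinary Python floats.

```python
import math
import pickle

# Stand-in for the fitted transformer's `dummies` list (assumed layout).
dummies = [1.0, 2.0, math.nan, 3.0]

# Round-trip through pickle, as happens when the HyperTransformer is saved.
reloaded = pickle.loads(pickle.dumps(dummies))

unique_data = {math.nan, 1.0, 2.0, 3.0}

# Same nan object on both sides: the identity check succeeds.
before = unique_data - set(dummies)
print(before)  # set()

# After the round trip the nan is a different float object, and
# nan != nan, so it is reported as "unseen".
after = unique_data - set(reloaded)
print(after)  # {nan}
print(reloaded[2] is math.nan)  # False
```

If this is what happens inside the transformer, any check that relies on set membership of nan would break after a save/reload, even though the values look identical when printed.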
Other transformers such as LabelEncoder use a different check, so it may be worth using that approach here:
mapped = data.fillna(np.nan).map(self.categories_to_values)
is_null = mapped.isna()
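As an illustration of a check that does not depend on nan identity, here is a minimal sketch that handles nan separately from the other categories. The helper names `_is_nan` and `find_unseen_categories` are hypothetical and not part of RDT; this is not the library's implementation, only one way the comparison could be made robust to pickling.

```python
import math
import pickle


def _is_nan(value):
    """Return True for float nan values (np.nan and np.float64 nan included)."""
    return isinstance(value, float) and math.isnan(value)


def find_unseen_categories(values, known):
    """Return the values missing from `known`, treating every nan as one category."""
    known_non_null = {k for k in known if not _is_nan(k)}
    known_has_nan = any(_is_nan(k) for k in known)

    # Compare non-null values with ordinary set membership.
    unseen = {v for v in values if not _is_nan(v) and v not in known_non_null}

    # Report nan only if no nan at all was seen during fitting.
    if not known_has_nan and any(_is_nan(v) for v in values):
        unseen.add(math.nan)
    return unseen


# Simulate a save/reload of the fitted categories.
dummies = pickle.loads(pickle.dumps([1.0, 2.0, math.nan, 3.0]))
data = [1.0, 2.0, math.nan, 2.0, 3.0, math.nan, 3.0]

no_warning = find_unseen_categories(data, dummies)
real_unseen = find_unseen_categories(data + [4.0], dummies)
print(no_warning)   # set()
print(real_unseen)  # {4.0}
```

With a check along these lines, the reloaded transformer would only warn about genuinely new categories such as 4.0.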