Environment Details
Please indicate the following details about the environment in which you found the bug:
- RDT version: 0.5.2.dev0
- Python version: 3.8
- Operating System: macOS 11.5.1
Error Description
After optimizing the `OneHotEncoder` transformer, `CopulaGAN`, `CTGAN`, and `TVAE` started failing in SDV. The failure is caused by the line where CTGAN's `DataTransformer` accesses `ohe.dummies`. In PR #186, we changed `self.dummies` so that it no longer includes NaN values (because `pd.Categorical` does not accept NaN as a category). As a consequence, later parts of the code initialize weights with the wrong dimensions, because the NaN value is missing from `ohe.dummies`.
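The mismatch can be illustrated with plain pandas. This is a minimal sketch, not RDT's or CTGAN's code. It uses the reproduction column from this issue to show that `pd.Categorical` rejects NaN as an explicit category, and that a category list without NaN has one entry fewer than a one-hot encoding that keeps a missing-value column. `dummies_without_nan` is a hypothetical name standing in for the post-#186 `dummies`.

```python
import numpy as np
import pandas as pd

# The reproduction column from this issue: three categories plus a missing value.
column = pd.Series(["1", "2", "3", None] * 4)

# pandas rejects NaN as an explicit category, which is why NaN was
# dropped from `dummies` in the first place.
try:
    pd.Categorical(column, categories=["1", "2", "3", np.nan])
    nan_rejected = False
except ValueError:
    nan_rejected = True

# Hypothetical category list without NaN (analogous to the post-#186 `dummies`).
dummies_without_nan = list(column.dropna().unique())

# One-hot output that keeps a dedicated column for missing values.
one_hot = pd.get_dummies(column, dummy_na=True)

print(nan_rejected)              # True
print(len(dummies_without_nan))  # 3
print(one_hot.shape[1])          # 4
```

Any code that sizes its layers or spans from the 3-entry list while the data actually has 4 one-hot columns ends up with a shape mismatch like the one the `DataSampler` assertion catches.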
Steps to reproduce
Fitting a model on any DataFrame with missing values crashes. For example:
```python
import pandas as pd
from sdv.tabular import CTGAN

data = pd.DataFrame({"category": ["1", "2", "3", None] * 4})
ctgan = CTGAN()
ctgan.fit(data)
```
The code crashes with the following message:
```
AssertionError Traceback (most recent call last)
<ipython-input-8-8847fc3cae96> in <module>
----> 1 ctgan.fit(data)

~/Downloads/repos/SDV/sdv/tabular/base.py in fit(self, data)
140 LOGGER.debug(
141 'Fitting %s model to table %s', self.__class__.__name__, self._metadata.name)
--> 142 self._fit(transformed)
143
144 def get_metadata(self):

~/Downloads/repos/SDV/sdv/tabular/ctgan.py in _fit(self, table_data)
55 categoricals.append(field)
56
---> 57 self._model.fit(
58 table_data,
59 discrete_columns=categoricals

~/opt/anaconda3/envs/rdt/lib/python3.8/site-packages/ctgan/synthesizers/ctgan.py in fit(self, train_data, discrete_columns, epochs)
296 train_data = self._transformer.transform(train_data)
297
--> 298 self._data_sampler = DataSampler(
299 train_data,
300 self._transformer.output_info_list,

~/opt/anaconda3/envs/rdt/lib/python3.8/site-packages/ctgan/data_sampler.py in __init__(self, data, output_info, log_frequency)
37 else:
38 st += sum([span_info.dim for span_info in column_info])
---> 39 assert st == data.shape[1]
40
41 # Prepare an interval matrix for efficiently sample conditional vector

AssertionError:
```
Steps to fix
- Revisit the variables in `OneHotEncoder` and let `dummies` contain `np.nan` when the input data has missing values.
- Add a unit test to confirm that `np.nan` is part of `dummies` when `dummy_na` is set to True.
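The two fix steps could look roughly like the sketch below. `SketchOneHotEncoder` and `test_dummies_contains_nan_when_dummy_na` are hypothetical names written for illustration; this is not RDT's `OneHotEncoder` or its API. The sketch appends `np.nan` to `dummies` when the fitted data has missing values and builds the one-hot matrix without `pd.Categorical`, so NaN can still get its own column. The test checks that NaN is in `dummies` when `dummy_na` is True and that the number of dummies matches the width of the transformed output, roughly the consistency that the failing `DataSampler` assertion relies on.

```python
import numpy as np
import pandas as pd


class SketchOneHotEncoder:
    """Hypothetical, simplified stand-in for an OneHotEncoder (not RDT's code)."""

    def fit(self, data):
        data = pd.Series(data)
        # Remember whether missing values were seen during fit.
        self.dummy_na = bool(data.isna().any())
        self.dummies = list(data.dropna().unique())
        if self.dummy_na:
            # Keep NaN as a category so downstream code sees the full width.
            self.dummies.append(np.nan)
        return self

    def transform(self, data):
        data = pd.Series(data)
        known = [value for value in self.dummies if not pd.isna(value)]
        # One column per non-null category...
        columns = [(data == value).to_numpy() for value in known]
        # ...plus one column flagging missing values, if NaN was seen in fit.
        if self.dummy_na:
            columns.append(data.isna().to_numpy())
        return np.column_stack(columns).astype(int)


def test_dummies_contains_nan_when_dummy_na():
    data = pd.Series(["1", "2", "3", None] * 4)
    ohe = SketchOneHotEncoder().fit(data)

    assert ohe.dummy_na is True
    # NaN compares unequal to itself, so check with pd.isna.
    assert any(pd.isna(value) for value in ohe.dummies)
    # The number of dummies matches the width of the transformed output.
    assert len(ohe.dummies) == ohe.transform(data).shape[1]


test_dummies_contains_nan_when_dummy_na()
```

Building the columns with plain comparisons sidesteps the `pd.Categorical` restriction that motivated dropping NaN in PR #186; RDT's actual fix may take a different route.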