Closed
Description
Environment Details
Please indicate the following details about the environment in which you found the bug:
- RDT version: 0.5.2.dev0
- Python version: 3.8
- Operating System: macOS 11.5.1
Error Description
After the OneHotEncoder transformer was optimized, CopulaGAN, CTGAN, and TVAE started failing in SDV. The failure is caused by the line:
where DataTransformer accesses ohe.dummies. In PR #186, we changed self.dummies so that it no longer includes NaN values (because pd.Categorical does not accept NaN as a category). As a consequence, later parts of the code initialize weights with the wrong dimensions, because the NaN value is missing from ohe.dummies.
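To make the mismatch concrete, here is a minimal sketch in plain pandas. It is not RDT's actual OneHotEncoder code: the helper fit_dummies is hypothetical and only mimics the post-#186 behaviour described above, where NaN is dropped from dummies because pd.Categorical refuses a null category, while the one-hot output still needs a column for NaN rows.

```python
import numpy as np
import pandas as pd

data = pd.Series(["1", "2", "3", None] * 4)

# pd.Categorical refuses NaN as an explicit category.
try:
    pd.Categorical(data, categories=["1", "2", "3", np.nan])
except ValueError as err:
    print("ValueError:", err)


def fit_dummies(series):
    """Hypothetical stand-in for the post-#186 behaviour: NaN is dropped."""
    return list(pd.unique(series.dropna()))


dummies = fit_dummies(data)
has_nan = bool(data.isna().any())

# The encoded output still needs one extra column for the NaN rows,
# so its width no longer equals len(dummies).
output_width = len(dummies) + int(has_nan)
print(len(dummies), output_width)
```

Any code that sizes its buffers from len(dummies) will therefore be one column short whenever the fitted data contained missing values.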
Steps to reproduce
Fitting a model on any DataFrame that contains missing values crashes. For example:
```python
import pandas as pd
from sdv.tabular import CTGAN

data = pd.DataFrame({"category": ["1", "2", "3", None] * 4})
ctgan = CTGAN()
ctgan.fit(data)
```

The code crashes with the following message:
```
AssertionError Traceback (most recent call last)
<ipython-input-8-8847fc3cae96> in <module>
----> 1 ctgan.fit(data)
~/Downloads/repos/SDV/sdv/tabular/base.py in fit(self, data)
140 LOGGER.debug(
141 'Fitting %s model to table %s', self.__class__.__name__, self._metadata.name)
--> 142 self._fit(transformed)
143
144 def get_metadata(self):
~/Downloads/repos/SDV/sdv/tabular/ctgan.py in _fit(self, table_data)
55 categoricals.append(field)
56
---> 57 self._model.fit(
58 table_data,
59 discrete_columns=categoricals
~/opt/anaconda3/envs/rdt/lib/python3.8/site-packages/ctgan/synthesizers/ctgan.py in fit(self, train_data, discrete_columns, epochs)
296 train_data = self._transformer.transform(train_data)
297
--> 298 self._data_sampler = DataSampler(
299 train_data,
300 self._transformer.output_info_list,
~/opt/anaconda3/envs/rdt/lib/python3.8/site-packages/ctgan/data_sampler.py in __init__(self, data, output_info, log_frequency)
37 else:
38 st += sum([span_info.dim for span_info in column_info])
---> 39 assert st == data.shape[1]
40
41 # Prepare an interval matrix for efficiently sample conditional vector
AssertionError:
```
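The failing assertion in DataSampler checks that the column widths declared in output_info add up to the actual width of the transformed data. A simplified sketch of that check, assuming a single discrete column whose declared width comes from len(ohe.dummies) (NaN already dropped) while the transformed data still carries the extra NaN column. SpanInfo and widths_match are hypothetical stand-ins, not CTGAN's own classes or functions.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for CTGAN's span metadata: a width and an activation.
SpanInfo = namedtuple("SpanInfo", ["dim", "activation_fn"])

dummies = ["1", "2", "3"]  # NaN was dropped from dummies
output_info = [[SpanInfo(len(dummies), "softmax")]]

# The encoded data has one column per category plus one for NaN.
n_rows = 16
data = np.zeros((n_rows, len(dummies) + 1))


def widths_match(data, output_info):
    """Mirror the spirit of the `st == data.shape[1]` check in DataSampler."""
    st = sum(span.dim for column_info in output_info for span in column_info)
    return st == data.shape[1]


print(widths_match(data, output_info))  # 3 declared columns vs 4 actual
```

Here the declared width is 3 while the data has 4 columns, so the check is False, which is the same condition that raises the AssertionError above.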
Steps to fix
- Revisit the variables in OneHotEncoder and let dummies contain np.nan when the data contains missing values.
- Add a unit test to confirm that np.nan is part of dummies when dummy_na is set to True.
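A minimal sketch of both steps, using a hypothetical SimpleOneHotEncoder rather than RDT's real class. In this sketch, fit sets its own dummy_na flag when the data has missing values and appends np.nan to dummies, while pd.Categorical only ever receives the non-null categories. The test function at the end checks that NaN is in dummies when dummy_na is True and that the output width matches len(dummies).

```python
import numpy as np
import pandas as pd


class SimpleOneHotEncoder:
    """Hypothetical encoder illustrating the proposed fix, not RDT's code."""

    def fit(self, data):
        data = pd.Series(data)
        self.dummy_na = bool(data.isna().any())
        # pd.Categorical only ever receives the non-null categories...
        self._categories = list(pd.unique(data.dropna()))
        # ...but dummies also reports NaN, so downstream sizes are correct.
        self.dummies = self._categories + ([np.nan] if self.dummy_na else [])
        return self

    def transform(self, data):
        data = pd.Series(data)
        codes = pd.Categorical(data, categories=self._categories).codes
        out = np.zeros((len(data), len(self.dummies)))
        rows = np.arange(len(data))
        known = codes >= 0
        out[rows[known], codes[known]] = 1
        if self.dummy_na:
            # Code -1 marks values outside the fitted categories (here, NaN);
            # they go to the last column, which represents NaN.
            out[rows[~known], -1] = 1
        return out


def test_dummies_contain_nan_when_dummy_na():
    data = pd.Series(["1", "2", "3", None] * 4)
    ohe = SimpleOneHotEncoder().fit(data)
    assert ohe.dummy_na
    # pd.isna is used because NaN != NaN under ordinary equality.
    assert any(pd.isna(value) for value in ohe.dummies)
    assert ohe.transform(data).shape[1] == len(ohe.dummies)


test_dummies_contain_nan_when_dummy_na()
```

Keeping NaN out of the categories passed to pd.Categorical avoids its null-category error, while exposing it in dummies keeps consumers such as DataTransformer in agreement with the actual output width.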