Skip to content

Commit

Permalink
Fixes (#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho authored Sep 7, 2023
1 parent 4f75aae commit 8fc2e5e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
16 changes: 8 additions & 8 deletions rdt/transformers/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,6 @@ def _fit(self, data):
n_components=self.max_clusters,
weight_concentration_prior_type='dirichlet_process',
weight_concentration_prior=0.001,
n_init=1,
random_state=self._get_current_random_seed()
)

Expand Down Expand Up @@ -494,10 +493,12 @@ def _transform(self, data):

data = data.reshape((len(data), 1))
means = self._bgm_transformer.means_.reshape((1, self.max_clusters))

means = means[:, self.valid_component_indicator]
stds = np.sqrt(self._bgm_transformer.covariances_).reshape((1, self.max_clusters))
stds = stds[:, self.valid_component_indicator]

# Multiply stds by 4 so that a value will be in the range [-1,1] with 99.99% probability
normalized_values = (data - means) / (self.STD_MULTIPLIER * stds)
normalized_values = normalized_values[:, self.valid_component_indicator]
component_probs = self._bgm_transformer.predict_proba(data)
component_probs = component_probs[:, self.valid_component_indicator]

Expand All @@ -524,7 +525,8 @@ def _reverse_transform_helper(self, data):
normalized = np.clip(data[:, 0], -1, 1)
means = self._bgm_transformer.means_.reshape([-1])
stds = np.sqrt(self._bgm_transformer.covariances_).reshape([-1])
selected_component = data[:, 1].astype(int) # maybe round instead?
selected_component = data[:, 1].round().astype(int)
selected_component = selected_component.clip(0, self.valid_component_indicator.sum() - 1)
std_t = stds[self.valid_component_indicator][selected_component]
mean_t = means[self.valid_component_indicator][selected_component]
reversed_data = normalized * self.STD_MULTIPLIER * std_t + mean_t
Expand All @@ -546,8 +548,6 @@ def _reverse_transform(self, data):

recovered_data = self._reverse_transform_helper(data)
if self.null_transformer and self.null_transformer.models_missing_values():
data = np.stack([recovered_data, data[:, -1]], axis=1) # noqa: PD013
else:
data = recovered_data
recovered_data = np.stack([recovered_data, data[:, -1]], axis=1) # noqa: PD013

return super()._reverse_transform(data)
return super()._reverse_transform(recovered_data)
1 change: 0 additions & 1 deletion tests/unit/transformers/test_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,6 @@ def test__fit(self, mock_bgm):
n_components=10,
weight_concentration_prior_type='dirichlet_process',
weight_concentration_prior=0.001,
n_init=1,
random_state=0
)

Expand Down

0 comments on commit 8fc2e5e

Please sign in to comment.