Skip to content

Commit

Permalink
Fix test for parametric bootstrapping with covariates
Browse files Browse the repository at this point in the history
Covariate models cannot generate samples.
  • Loading branch information
sachaMorin committed Dec 21, 2023
1 parent c7b1ff5 commit 161b805
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
7 changes: 7 additions & 0 deletions stepmix/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,14 @@ def bootstrap(
check_is_fitted(estimator)
estimator = copy.deepcopy(estimator)
estimator.set_params(random_state=random_state)

if parametric and estimator._is_covariate:
raise ValueError("Parametric bootstrapping is not supported for covariate models.")

if sampler is not None:
if sampler._is_covariate:
raise ValueError("Parametric bootstrapping is not supported for covariate models.")

check_is_fitted(sampler)
sampler = copy.deepcopy(sampler)
sampler.set_params(random_state=random_state)
Expand Down
4 changes: 2 additions & 2 deletions stepmix/stepmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,11 @@ def _check_initial_parameters(self, X):
self.param_buffer_ = list()

# Covariate models have special constraints. Check them.
is_covariate = utils.check_covariate(self.measurement, self.structural)
self._is_covariate = utils.check_covariate(self.measurement, self.structural)

# Covariate models use a different conditional likelihood (See Bakk and Kuha, 2018), which should
# not include the marginal likelihood over the latent classes in the E-step
self._conditional_likelihood = is_covariate
self._conditional_likelihood = self._is_covariate

def _initialize_parameters(self, X, random_state):
"""Initialize the weights and measurement model parameters.
Expand Down
8 changes: 7 additions & 1 deletion test/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,13 @@ def test_bootstrap(data, kwargs, model):

model_1 = StepMix(n_steps=1, **kwargs)
model_1.fit(X, Y)
model_1.bootstrap_stats(X, Y, n_repetitions=3)

if model is not 'covariate':
model_1.bootstrap_stats(X, Y, n_repetitions=3)
else:
# Should raise error. Can't sample from a covariate model
with pytest.raises(ValueError) as e_info:
model_1.bootstrap_stats(X, Y, n_repetitions=3)


def test_nested_bootstrap(data_nested, kwargs_nested):
Expand Down

0 comments on commit 161b805

Please sign in to comment.