Skip to content

Commit

Permalink
Binary structural inference
Browse files Browse the repository at this point in the history
  • Loading branch information
sachaMorin committed Dec 22, 2023
1 parent ec9824d commit 5a76e9d
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 1 deletion.
70 changes: 70 additions & 0 deletions stepmix/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,76 @@ def data_generation_gaussian(n_samples, sep_level, n_mm=6, random_state=None):

return X, Y, labels

def data_gaussian_binary(n_samples, n_mm=6, random_state=None):
"""Full Gaussian measurement model with 2 binary responses.
The data has 4 latent classes.
Parameters
----------
n_samples : int
Number of samples.
random_state: int
Random state.
Returns
-------
X : ndarray of shape (n_samples, 2)
Gaussian Measurement samples.
Y : ndarray of shape (n_samples, 2)
Binary Structural samples.
labels : ndarray of shape (n_samples,)
Ground truth class membership.
"""
n_classes = 4 # Number of latent classes
n_sm = 2 # Dimensions of the response variable Zo

# True parameters
# rho[c] = p(X=c)
rho = np.ones(n_classes) / n_classes

# mus[k] = E[Z0|X=c]
# mus[k] = E[Z0|X=c]
mus = np.array(
[[-2.0344, 4.1726], [3.9779, 3.7735], [3.8007, -3.7972], [-3.0620, -3.5345]]
)

# sigmas[k] = V[Z0|X=c]
sigmas = np.array(
[
[[2.9044, 0.2066], [0.2066, 2.7562]],
[[0.2104, 0.2904], [0.2904, 12.2392]],
[[0.9213, 0.0574], [0.0574, 1.8660]],
[[6.2414, 6.0502], [6.0502, 6.1825]],
]
)

pis = np.array([
[0.1, 0.9, 0.1, 0.9],
[0.8, 0.6, 0.4, 0.2]
])

# Model parameters
params = dict(
weights=rho,
measurement=dict(pis=pis.T),
structural=dict(means=mus, covariances=sigmas),
measurement_in=n_mm,
structural_in=n_sm,
)

# Sample data
generator = StepMix(
n_components=n_classes,
measurement="bernoulli",
structural="gaussian_full",
random_state=random_state,
)
generator.set_parameters(params)
Y, X, labels = generator.sample(n_samples)

return X, Y, labels

def data_gaussian_diag(n_samples, sep_level, n_mm=6, random_state=None, nan_ratio=0.0):
"""Bakk binary measurement model with 2D diagonal gaussian structural model.
Expand Down
8 changes: 8 additions & 0 deletions stepmix/emission/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ def log_likelihood(self, X):
log_eps = X @ np.log(pis) + (1 - X) @ np.log(1 - pis)
return log_eps

def predict_proba(self, log_resp):
resp = np.exp(log_resp)
return resp @ self.parameters["pis"]

def predict(self, log_resp):
probs = self.predict_proba(log_resp)
return (probs > 0.5).astype(int)

def sample(self, class_no, n_samples):
feature_weights = self.parameters["pis"][class_no, :].reshape((1, -1))
K = feature_weights.shape[1] # number of features
Expand Down
36 changes: 36 additions & 0 deletions stepmix/emission/emission.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,42 @@ def log_likelihood(self, X):
"""
raise NotImplementedError

def predict_proba(self, log_resp):
"""Compute the conditional probabilities P(Y|X) given the log responsibilities P(Z|X).
This will only be used if the emission model is used as a structural model. X therefore represents the input
and Y the output for supervised predictions.
Parameters
----------
log_resp : ndarray of shape (n_samples, n_components)
Logarithm of the posterior probabilities P(Z|X) (or responsibilities) of each sample.
Returns
-------
resp : ndarray of shape (n_samples, n_columns)
Conditional probabilities P(Y|X) of each sample.
"""
raise NotImplementedError("This emission model does not support predictions.")

def predict(self, log_resp):
"""Compute argmax P(Y|X) given the log responsibilities P(Z|X) for supervised predictions.
This will only be used if the emission model is used as a structural model. X therefore represents the input
and Y the output for supervised predictions.
Parameters
----------
log_resp : ndarray of shape (n_samples, n_components)
Logarithm of the posterior probabilities P(Z|X) (or responsibilities) of each sample.
Returns
-------
resp : ndarray of shape (n_samples, n_columns)
Argmax P(Y|X) of each sample.
"""
raise NotImplementedError("This emission model does not support predictions.")

@abstractmethod
def sample(self, class_no, n_samples):
"""Sample n_samples conditioned on the given class_no.
Expand Down
56 changes: 55 additions & 1 deletion stepmix/stepmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,7 @@ def caic(self, X, Y=None):
return -2 * self.score(X, Y) * n + self.n_parameters * (np.log(n) + 1)

def predict(self, X, Y=None):
"""Predict the labels for the data samples in X using the measurement model.
"""Predict the cluster/latent class labels for the data samples in X using the measurement model.
Optionally, an array-like Y can be provided to predict the labels based on both the measurement and structural
models.
Expand Down Expand Up @@ -1420,6 +1420,60 @@ def predict_proba(self, X, Y=None):

_, log_resp = self._e_step(X, Y=Y)
return np.exp(log_resp)
def predict_Y(self, X, Y=None):
"""Call the predict method of the structural model to predict argmax P(Y|X) (Supervised prediction).
Parameters
----------
X : array-like of shape (n_samples, n_features)
List of n_features-dimensional data points to fit the measurement model. Each row
corresponds to a single data point. If the data is categorical, by default it should be
0-indexed and integer encoded (not one-hot encoded).
Y : array-like of shape (n_samples, n_features_structural), default=None
List of n_features-dimensional data points to fit the structural model. Each row
corresponds to a single data point. If the data is categorical, by default it should be
0-indexed and integer encoded (not one-hot encoded).
Returns
-------
predictions : array, shape (n_samples, n_columns)
Y predictions.
"""
if not hasattr(self, "_sm"):
raise ValueError("Calling predict_Y requires a structural model.")
check_is_fitted(self)
X, Y = self._check_x_y(X, Y)

_, log_resp = self._e_step(X, Y=Y)

return self._sm.predict(log_resp)

def predict_proba_Y(self, X, Y=None):
"""Call the predict method of the structural model to predict the full conditional P(Y|X).
Parameters
----------
X : array-like of shape (n_samples, n_features)
List of n_features-dimensional data points to fit the measurement model. Each row
corresponds to a single data point. If the data is categorical, by default it should be
0-indexed and integer encoded (not one-hot encoded).
Y : array-like of shape (n_samples, n_features_structural), default=None
List of n_features-dimensional data points to fit the structural model. Each row
corresponds to a single data point. If the data is categorical, by default it should be
0-indexed and integer encoded (not one-hot encoded).
Returns
-------
conditional : array, shape (n_samples, n_columns)
P(Y|X).
"""
if not hasattr(self, "_sm"):
raise ValueError("Calling predict_proba_Y requires a structural model.")
check_is_fitted(self)
X, Y = self._check_x_y(X, Y)

_, log_resp = self._e_step(X, Y=Y)

return self._sm.predict_proba(log_resp)

def sample(self, n_samples, labels=None):
"""Sample method for fitted StepMix model.
Expand Down

0 comments on commit 5a76e9d

Please sign in to comment.