Skip to content

Commit

Permalink
Categorical predict and predict_proba
Browse files Browse the repository at this point in the history
  • Loading branch information
sachaMorin committed Dec 28, 2023
1 parent 7d4c16c commit 3e551bb
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 4 deletions.
80 changes: 76 additions & 4 deletions stepmix/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def data_generation_gaussian(n_samples, sep_level, n_mm=6, random_state=None):

return X, Y, labels

def data_gaussian_binary(n_samples, n_mm=6, random_state=None):
def data_gaussian_binary(n_samples, random_state=None):
"""Full Gaussian measurement model with 2 binary responses.
The data has 4 latent classes.
Expand All @@ -322,7 +322,6 @@ def data_gaussian_binary(n_samples, n_mm=6, random_state=None):
"""
n_classes = 4 # Number of latent classes
n_sm = 2 # Dimensions of the response variable Zo

# True parameters
# rho[c] = p(X=c)
Expand Down Expand Up @@ -354,8 +353,8 @@ def data_gaussian_binary(n_samples, n_mm=6, random_state=None):
weights=rho,
measurement=dict(pis=pis.T),
structural=dict(means=mus, covariances=sigmas),
measurement_in=n_mm,
structural_in=n_sm,
measurement_in=2,
structural_in=2,
)

# Sample data
Expand All @@ -370,6 +369,79 @@ def data_gaussian_binary(n_samples, n_mm=6, random_state=None):

return X, Y, labels

def data_gaussian_categorical(n_samples, random_state=None):
    """Sample from a full Gaussian measurement model with 2 categorical responses.

    The generative model has 4 latent classes with uniform class weights.
    Each class has a 2D full-covariance Gaussian measurement distribution and
    two 3-outcome categorical structural responses.

    Parameters
    ----------
    n_samples : int
        Number of samples.
    random_state : int
        Random state.

    Returns
    -------
    X : ndarray of shape (n_samples, 2)
        Gaussian Measurement samples.
    Y : ndarray of shape (n_samples, 2)
        Categorical Structural samples.
    labels : ndarray of shape (n_samples,)
        Ground truth class membership.
    """
    n_classes = 4  # Number of latent classes

    # class_weights[c] = p(X=c): uniform prior over the latent classes
    class_weights = np.full(n_classes, 1.0 / n_classes)

    # class_means[c] = E[Z0|X=c]
    class_means = np.array(
        [[-2.0344, 4.1726], [3.9779, 3.7735], [3.8007, -3.7972], [-3.0620, -3.5345]]
    )

    # class_covariances[c] = V[Z0|X=c]
    class_covariances = np.array(
        [
            [[2.9044, 0.2066], [0.2066, 2.7562]],
            [[0.2104, 0.2904], [0.2904, 12.2392]],
            [[0.9213, 0.0574], [0.0574, 1.8660]],
            [[6.2414, 6.0502], [6.0502, 6.1825]],
        ]
    )

    # Categorical outcome probabilities, one row per class.
    # 4 classes x (2 categorical variables x 3 outcomes) = 4 x 6
    response_probs = np.array(
        [
            [0.8, 0.1, 0.1, 0.9, 0.1, 0.0],
            [0.1, 0.1, 0.8, 0.9, 0.0, 0.1],
            [0.1, 0.8, 0.1, 0.1, 0.9, 0.0],
            [0.1, 0.1, 0.8, 0.1, 0.0, 0.9],
        ]
    )

    # Ground-truth model parameters in the format expected by StepMix
    true_params = {
        "weights": class_weights,
        "measurement": {"means": class_means, "covariances": class_covariances},
        "structural": {"pis": response_probs, "max_n_outcomes": 3, "total_outcomes": 2},
        "measurement_in": 2,
        "structural_in": 2,
    }

    # Build a StepMix model, inject the true parameters, and sample from it
    sampler = StepMix(
        n_components=n_classes,
        measurement="gaussian_full",
        structural="categorical",
        random_state=random_state,
    )
    sampler.set_parameters(true_params)
    X, Y, labels = sampler.sample(n_samples)

    return X, Y, labels

def data_gaussian_diag(n_samples, sep_level, n_mm=6, random_state=None, nan_ratio=0.0):
"""Bakk binary measurement model with 2D diagonal gaussian structural model.
Expand Down
15 changes: 15 additions & 0 deletions stepmix/emission/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,21 @@ def log_likelihood(self, X):
pis = np.clip(self.parameters["pis"].T, 1e-15, 1 - 1e-15)
log_eps = X @ np.log(pis)
return log_eps
def predict_proba(self, log_resp):
    """Return outcome probabilities for each categorical feature.

    Mixes the per-class outcome probabilities ``pis`` using the class
    responsibilities derived from ``log_resp``.

    Parameters
    ----------
    log_resp : ndarray of shape (n_samples, n_components)
        Log responsibilities (log posterior class probabilities).

    Returns
    -------
    ndarray of shape (n_samples, n_features * max_n_outcomes)
        Per-sample outcome probabilities, flattened over features/outcomes.
    """
    responsibilities = np.exp(log_resp)
    n_features = self.get_n_features()
    n_outcomes = self.parameters["max_n_outcomes"]

    # View pis as (class, feature, outcome) so we can mix over classes.
    pis_3d = self.parameters["pis"].reshape(
        (self.n_components, n_features, n_outcomes)
    )

    # Weighted sum over classes: probs[n, f, o] = sum_k resp[n, k] * pis[k, f, o]
    mixed = np.einsum('nk,kfo->nfo', responsibilities, pis_3d)

    # Flatten back to the 2D (sample, feature*outcome) layout.
    return mixed.reshape((log_resp.shape[0], n_features * n_outcomes))

def predict(self, log_resp):
    """Return the most likely outcome for each categorical feature.

    Parameters
    ----------
    log_resp : ndarray of shape (n_samples, n_components)
        Log responsibilities (log posterior class probabilities).

    Returns
    -------
    ndarray of shape (n_samples * n_features,)
        Index of the argmax outcome per (sample, feature), flattened to 1D.
    """
    shape_3d = (
        log_resp.shape[0],
        self.get_n_features(),
        self.parameters["max_n_outcomes"],
    )
    # Recover the (sample, feature, outcome) view of the mixed probabilities,
    # then pick the highest-probability outcome along the outcome axis.
    per_outcome = self.predict_proba(log_resp).reshape(shape_3d)
    return per_outcome.argmax(axis=2).flatten()


def sample(self, class_no, n_samples):
pis = self.parameters["pis"].T
Expand Down

0 comments on commit 3e551bb

Please sign in to comment.