
Added change of variable for both parameters.
x-tabdeveloping committed Jun 14, 2023
1 parent 97c9605 commit dcd2608
Showing 1 changed file with 40 additions and 23 deletions.
tweetopic/bayesian/dmm.py: 40 additions & 23 deletions
@@ -21,21 +21,6 @@ def symmetric_dirichlet_multinomial_mean(alpha: float, n: int, K: int):
     return np.full(K, n * alpha / (K * alpha))
 
 
-def init_parameters(
-    n_docs: int, n_vocab: int, n_components: int, alpha: float, beta: float
-) -> dict:
-    """Initializes the parameters of the dmm to the mean of the prior."""
-    return dict(
-        weights=symmetric_dirichlet_multinomial_mean(
-            alpha, n_docs, n_components
-        ),
-        components=np.broadcast_to(
-            scipy.stats.dirichlet.mean(np.full(n_vocab, beta)),
-            (n_components, n_vocab),
-        ),
-    )
-
-
 def sparse_multinomial_logpdf(
     component,
     unique_words,
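
For reference, the mean returned by symmetric_dirichlet_multinomial_mean at the top of the hunk above reduces to n / K: each component of a symmetric Dirichlet has mean alpha / (K * alpha) = 1 / K, so n multinomial draws land on each component n / K times in expectation, independent of alpha. A quick check of the formula, as a sketch assuming only NumPy:

import numpy as np

n, K, alpha = 100, 4, 0.3
mean = np.full(K, n * alpha / (K * alpha))
print(mean)                      # [25. 25. 25. 25.]
assert np.allclose(mean, n / K)  # the result does not depend on alpha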
@@ -60,12 +45,8 @@ def sparse_multinomial_logpdf(
 def symmetric_dirichlet_logpdf(x, alpha):
     """Logdensity of a symmetric Dirichlet."""
     K = x.shape[0]
-    sums_to_one = jnp.abs(1 - jnp.sum(x)) <= 0.001
-    all_bigger_than_zero = jnp.all(x >= 0)
     return (
-        jnp.log(sums_to_one)
-        + jnp.log(all_bigger_than_zero)
-        + jax.lax.lgamma(alpha * K)
+        jax.lax.lgamma(alpha * K)
         - K * jax.lax.lgamma(alpha)
         + (alpha - 1) * jnp.sum(jnp.nan_to_num(jnp.log(x)))
     )
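
With the change of variables introduced further down in this diff, the density is only ever evaluated on points that already lie on the simplex, so the sums_to_one and all_bigger_than_zero indicator terms (which pushed the log-density to -inf outside the support) become redundant. A minimal sketch, assuming transform_components (defined below) and symmetric_dirichlet_logpdf are in scope:

import jax.numpy as jnp

raw = jnp.array([[1.5, -0.3, 0.8]])  # unconstrained proposal from a sampler
simplex = transform_components(raw)  # squared and row-normalized
assert jnp.allclose(jnp.sum(simplex, axis=1), 1.0)
print(symmetric_dirichlet_logpdf(simplex[0], alpha=0.5))  # finite log-density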
@@ -122,10 +103,45 @@ def posterior_predictive(
     return predict_all(doc_unique_words, doc_unique_word_counts)
 
 
+def init_parameters(
+    n_docs: int, n_vocab: int, n_components: int, alpha: float, beta: float
+) -> dict:
+    """Initializes the parameters of the dmm to the mean of the prior."""
+    return dict(
+        weights=symmetric_dirichlet_multinomial_mean(
+            alpha, n_docs, n_components
+        ),
+        components=np.broadcast_to(
+            scipy.stats.dirichlet.mean(np.full(n_vocab, beta)),
+            (n_components, n_vocab),
+        ),
+    )
+
+
+def transform_components(components):
+    """Transforms the components parameter so that the sampling
+    space is unconstrained; this is great for HMC and SGMCMC."""
+    squared_components = jnp.square(components)
+    return (squared_components.T / jnp.sum(squared_components, axis=1)).T
+
+
+def transform_weights(weights, n_docs):
+    """Transforms the weights parameter so that the sampling
+    space is unconstrained; this is great for HMC and SGMCMC."""
+    sq_weights = jnp.square(weights)
+    norm_weights = sq_weights / jnp.sum(sq_weights)
+    scaled_weights = norm_weights * n_docs
+    return jnp.round(scaled_weights)
+
+
 def dmm_loglikelihood(
-    components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
+    components, weights, doc_unique_words, doc_unique_word_counts
 ):
     """Loglikelihood of the Dirichlet multinomial mixture model."""
     docs = jnp.stack((doc_unique_words, doc_unique_word_counts), axis=1)
+    n_docs = doc_unique_words.shape[0]
+    components = transform_components(components)
+    weights = transform_weights(weights, n_docs)
 
     def doc_likelihood(doc):
         unique_words, unique_word_counts = doc
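
The two transforms above are the change of variables named in the commit message: HMC and SGMCMC propose arbitrary real vectors, and squaring followed by normalization maps each proposal back onto the valid parameter space, a simplex per component row and non-negative counts summing to roughly n_docs for the weights. A sketch of the mapping, assuming the functions are imported from tweetopic.bayesian.dmm:

import jax.numpy as jnp
from tweetopic.bayesian.dmm import transform_components, transform_weights

raw_components = jnp.array([[0.2, -1.0, 3.0], [1.0, 1.0, 1.0]])
components = transform_components(raw_components)
print(components.sum(axis=1))  # [1. 1.], every row is a distribution

raw_weights = jnp.array([2.0, -1.0, 1.0])
print(transform_weights(raw_weights, n_docs=60))  # [40. 10. 10.]

Note that the squaring map is not injective: x and -x land on the same constrained point, which gradient-based samplers generally tolerate.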
@@ -144,6 +160,9 @@ def doc_likelihood(doc):
 
 
 def dmm_logprior(components, weights, alpha, beta, n_docs):
     """Logprior of the Dirichlet multinomial mixture model."""
+    components = transform_components(components)
+    weights = transform_weights(weights, n_docs)
     components_prior = jnp.sum(
         jax.lax.map(
             partial(symmetric_dirichlet_logpdf, alpha=alpha), components
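
The jax.lax.map plus functools.partial pattern above evaluates the symmetric Dirichlet prior once per mixture component and sums the per-row log-densities into a joint prior over all component distributions. A self-contained sketch of the same pattern, assuming only JAX:

from functools import partial

import jax
import jax.numpy as jnp

def row_logpdf(x, alpha):
    # Stand-in for symmetric_dirichlet_logpdf from this file.
    K = x.shape[0]
    return (
        jax.lax.lgamma(alpha * K)
        - K * jax.lax.lgamma(alpha)
        + (alpha - 1) * jnp.sum(jnp.log(x))
    )

rows = jnp.full((5, 3), 1.0 / 3.0)  # five uniform distributions over 3 words
total = jnp.sum(jax.lax.map(partial(row_logpdf, alpha=2.0), rows))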
@@ -165,8 +184,6 @@ def dmm_logpdf(
         weights,
         doc_unique_words,
         doc_unique_word_counts,
-        alpha,
-        beta,
     )
     logprior = dmm_logprior(components, weights, alpha, beta, n_docs)
     return logprior + loglikelihood
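
Since both parameters now live in unconstrained space, the joint density can be differentiated directly and handed to a gradient-based sampler with no projection step. A hedged sketch with toy data; the keyword signature of dmm_logpdf is an assumption pieced together from the call sites in this diff, not confirmed by it:

import jax
import jax.numpy as jnp
from tweetopic.bayesian.dmm import dmm_logpdf

# Toy corpus: 4 documents, up to 3 unique words each, vocabulary of 6.
doc_unique_words = jnp.array([[0, 2, 5], [1, 4, 0], [3, 4, 0], [2, 1, 3]])
doc_unique_word_counts = jnp.array(
    [[2, 1, 1], [3, 1, 0], [1, 2, 0], [4, 1, 1]], dtype=jnp.float32
)

# Unconstrained starting values; the transforms inside the density
# map them onto the simplex and the document counts.
components = jnp.ones((2, 6))
weights = jnp.ones(2)

# Keyword signature assumed from the call sites above.
logdensity = lambda c, w: dmm_logpdf(
    c, w, doc_unique_words, doc_unique_word_counts,
    alpha=1.0, beta=1.0, n_docs=4,
)
grads = jax.grad(logdensity, argnums=(0, 1))(components, weights)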
