Skip to content

Commit

Permalink
Add support for fit from list
Browse files Browse the repository at this point in the history
  • Loading branch information
krzysztofrusek committed Oct 16, 2023
1 parent 40db0c8 commit 5994ba6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/gsd/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def fit_moments(data: ArrayLike) -> GSDParams:
:param data: A 5d Array of counts of each response.
:return: GSD Parameters
"""

data= jnp.asarray(data)
psi = jnp.dot(data, jnp.arange(1, 6)) / jnp.sum(data)
V = jnp.dot(data, jnp.arange(1, 6) ** 2) / jnp.sum(data) - psi ** 2
return GSDParams(psi=psi, rho=(vmax(psi) - V) / (vmax(psi) - vmin(psi)))
Expand Down Expand Up @@ -65,6 +67,7 @@ def fit_mle(data: ArrayLike, max_iterations: int = 100, log_lr_min: ArrayLike =
:return: An opt state whore params filed contains estimated values of GSD Parameters
"""

data = jnp.asarray(data)
def ll(theta: GSDParams) -> Array:
logits = jax.vmap(log_prob, (None, None, 0), (0))(theta.psi, theta.rho, jnp.arange(1, 6))
return jnp.dot(data, logits) / jnp.sum(data)
Expand Down
8 changes: 8 additions & 0 deletions tests/fit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ def test_mle(self):
self.assertAlmostEqual(os.params.rho, 1)


def test_list(self):
# 1 2 3 4 5
data=[0,10,10,0,0.]
_,os = gsd.fit.fit_mle(data)
self.assertAlmostEqual(os.params.psi, 2.5)
self.assertAlmostEqual(os.params.rho, 1)



if __name__ == '__main__':
unittest.main()

0 comments on commit 5994ba6

Please sign in to comment.