Skip to content

Commit 5994ba6

Browse files
Add support for fit from list
1 parent 40db0c8 commit 5994ba6

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

src/gsd/fit.py

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def fit_moments(data: ArrayLike) -> GSDParams:
2727
:param data: A 5d Array of counts of each response.
2828
:return: GSD Parameters
2929
"""
30+
31+
data= jnp.asarray(data)
3032
psi = jnp.dot(data, jnp.arange(1, 6)) / jnp.sum(data)
3133
V = jnp.dot(data, jnp.arange(1, 6) ** 2) / jnp.sum(data) - psi ** 2
3234
return GSDParams(psi=psi, rho=(vmax(psi) - V) / (vmax(psi) - vmin(psi)))
@@ -65,6 +67,7 @@ def fit_mle(data: ArrayLike, max_iterations: int = 100, log_lr_min: ArrayLike =
6567
:return: An opt state whore params filed contains estimated values of GSD Parameters
6668
"""
6769

70+
data = jnp.asarray(data)
6871
def ll(theta: GSDParams) -> Array:
6972
logits = jax.vmap(log_prob, (None, None, 0), (0))(theta.psi, theta.rho, jnp.arange(1, 6))
7073
return jnp.dot(data, logits) / jnp.sum(data)

tests/fit_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@ def test_mle(self):
1313
self.assertAlmostEqual(os.params.rho, 1)
1414

1515

16+
def test_list(self):
17+
# 1 2 3 4 5
18+
data=[0,10,10,0,0.]
19+
_,os = gsd.fit.fit_mle(data)
20+
self.assertAlmostEqual(os.params.psi, 2.5)
21+
self.assertAlmostEqual(os.params.rho, 1)
22+
23+
1624

1725
if __name__ == '__main__':
1826
unittest.main()

0 commit comments

Comments
 (0)