Skip to content

Commit

Permalink
Experimental Bayesian estimation with numerical integration
Browse files Browse the repository at this point in the history
  • Loading branch information
krzysztofrusek committed Oct 2, 2023
1 parent e2bf27f commit 90f40ac
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions tests/fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from jax import config

config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
import numpy as np
from scipy.integrate import dblquad

from gsd import log_prob

if __name__ == '__main__':
data = jnp.asarray([5, 12, 3, 0, 0])
k = jnp.arange(1, 6)


@jax.jit
def posterior(psi, rho):
log_posterior = jax.vmap(log_prob, in_axes=(None, None, 0))(psi, rho, k) @ data + 1. + 1 / 4.
posterior = jnp.exp(log_posterior)
return posterior


epsabs = 1e-14
epsreal = 1e-11

Z, Zerr = dblquad(posterior, a=0, b=1, gfun=lambda x: 1., hfun=lambda x: 5., epsabs=epsabs, epsrel=epsreal)
psi_hat, _ = dblquad(jax.jit(lambda psi, rho: psi * posterior(psi, rho)), a=0, b=1, gfun=lambda x: 1.,
hfun=lambda x: 5.,
epsabs=epsabs, epsrel=epsreal)
psi_hat = psi_hat / Z
rho_hat, _ = dblquad(jax.jit(lambda psi, rho: rho * posterior(psi, rho)), a=0, b=1, gfun=lambda x: 1.,
hfun=lambda x: 5.,
epsabs=epsabs, epsrel=epsreal)
rho_hat = rho_hat / Z

psi_ci, _ = dblquad(jax.jit(lambda psi, rho: (psi_hat - psi) ** 2 * posterior(psi, rho)), a=0, b=1,
gfun=lambda x: 1., hfun=lambda x: 5.,
epsabs=epsabs, epsrel=epsreal)

psi_ci = np.sqrt(psi_ci / Z)

rho_ci, _ = dblquad(jax.jit(lambda psi, rho: (rho_hat - rho) ** 2 * posterior(psi, rho)), a=0, b=1,
gfun=lambda x: 1., hfun=lambda x: 5.,
epsabs=epsabs, epsrel=epsreal)

rho_ci = np.sqrt(rho_ci / Z)

k @ data / data.sum()
pass

1 comment on commit 90f40ac

@krzysztofrusek
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LucjanJanowski FYI, we can estimate GSD with uncertainty in just few seconds (simple model only).

Please sign in to comment.