@@ -81,16 +81,18 @@ We first take a look at the density functions `f` and `g` .
8181class ImpSampleParams(NamedTuple):
8282 F_a: float = 1.0 # Beta parameters for f
8383 F_b: float = 1.0
84- G_a: float = 3.0 # Beta parameters for g
84+ G_a: float = 3.0 # Beta parameters for g
8585 G_b: float = 1.2
8686
8787params = ImpSampleParams()
8888
8989@jax.jit
9090def beta_pdf(w, a, b):
9191 """Beta probability density function."""
92- log_beta_const = gammaln(a) + gammaln(b) - gammaln(a + b)
93- log_pdf = (a - 1) * jnp.log(w) + (b - 1) * jnp.log(1 - w) - log_beta_const
92+ log_beta_const = (gammaln(a) + gammaln(b) -
93+ gammaln(a + b))
94+ log_pdf = ((a - 1) * jnp.log(w) + (b - 1) *
95+ jnp.log(1 - w) - log_beta_const)
9496 return jnp.exp(log_pdf)
9597
9698@jax.jit
@@ -194,8 +196,10 @@ mystnb:
194196---
195197w_range = jnp.linspace(1e-5, 1-1e-5, 1000)
196198
197- plt.plot(w_range, g(w_range), label=f'g=Beta({g_a}, {g_b})')
198- plt.plot(w_range, beta_pdf(w_range, 0.5, 0.5), label=f'h=Beta({h_a}, {h_b})')
199+ plt.plot(w_range, g(w_range),
200+ label=f'g=Beta({g_a}, {g_b})')
201+ plt.plot(w_range, beta_pdf(w_range, 0.5, 0.5),
202+ label=f'h=Beta({h_a}, {h_b})')
199203plt.legend()
200204plt.ylim([0., 3.])
201205plt.show()
@@ -230,18 +234,18 @@ def estimate_single_path(key, p_a, p_b, q_a, q_b, T):
230234 L, weight, key_state = carry
231235 key_state, subkey = jr.split(key_state)
232236 w = jr.beta(subkey, q_a, q_b)
233-
237+
234238 # Compute likelihood ratio using f/g functions
235239 likelihood_ratio = f(w) / g(w)
236240 L = L * likelihood_ratio
237-
241+
238242 # Importance sampling weight with beta_pdf
239243 p_w = beta_pdf(w, p_a, p_b)
240244 q_w = beta_pdf(w, q_a, q_b)
241245 weight = weight * (p_w / q_w)
242-
246+
243247 return (L, weight, key_state)
244-
248+
245249 # Use fori_loop for dynamic T values
246250 final_L, final_weight, _ = jax.lax.fori_loop(
247251 0, T, loop_body, (1.0, 1.0, key)
@@ -252,13 +256,13 @@ def estimate_single_path(key, p_a, p_b, q_a, q_b, T):
252256def estimate(key, p_a, p_b, q_a, q_b, T=1, N=10000):
253257 """Estimation of a batch of sample paths."""
254258 keys = jr.split(key, N)
255-
259+
256260 # Use vmap for vectorized computation
257261 estimates = jax.vmap(
258- estimate_single_path,
262+ estimate_single_path,
259263 in_axes=(0, *[None]*5)
260264 )(keys, p_a, p_b, q_a, q_b, T)
261-
265+
262266 return jnp.mean(estimates)
263267```
264268
@@ -313,21 +317,24 @@ The code below produces distributions of estimates using both Monte Carlo and i
313317
314318``` {code-cell} ipython3
315319@partial(jax.jit, static_argnames=['N_simu', 'N_samples'])
316- def simulate(key, p_a, p_b, q_a, q_b, N_simu, T=1, N_samples=1000):
320+ def simulate(key, p_a, p_b, q_a, q_b, N_simu, T=1,
321+ N_samples=1000):
317322 """Simulation for both Monte Carlo and importance sampling."""
318323 keys = jr.split(key, 2 * N_simu)
319324 keys_p = keys[:N_simu]
320325 keys_q = keys[N_simu:]
321-
326+
322327 def run_monte_carlo(key_batch):
323- return estimate(key_batch, p_a, p_b, p_a, p_b, T, N_samples)
324-
328+ return estimate(key_batch, p_a, p_b, p_a, p_b, T,
329+ N_samples)
330+
325331 def run_importance_sampling(key_batch):
326- return estimate(key_batch, p_a, p_b, q_a, q_b, T, N_samples)
327-
332+ return estimate(key_batch, p_a, p_b, q_a, q_b, T,
333+ N_samples)
334+
328335 μ_L_p = jax.vmap(run_monte_carlo)(keys_p)
329336 μ_L_q = jax.vmap(run_importance_sampling)(keys_q)
330-
337+
331338 return μ_L_p, μ_L_q
332339```
333340
@@ -358,25 +365,31 @@ Next, we present distributions of estimates for $\hat{E} \left[L\left(\omega^t\r
358365``` {code-cell} ipython3
359366T_values = [1, 5, 10, 20]
360367
361- def simulate_multiple_T(key, p_a, p_b, q_a, q_b, N_simu, T_list, N_samples=1000):
368+ def simulate_multiple_T(key, p_a, p_b, q_a, q_b, N_simu,
369+ T_list, N_samples=1000):
362370 """Simulation for multiple T values."""
363371 n_T = len(T_list)
364372 keys = jr.split(key, n_T)
365-
373+
366374 results = []
367375 for i, T in enumerate(T_list):
368- result = simulate(keys[i], p_a, p_b, q_a, q_b, N_simu, T, N_samples)
376+ result = simulate(keys[i],
377+ p_a, p_b, q_a, q_b, N_simu, T,
378+ N_samples)
369379 results.append(result)
370380
371381 # Stack results into arrays for consistency
372382 μ_L_p_all = jnp.stack([r[0] for r in results])
373383 μ_L_q_all = jnp.stack([r[1] for r in results])
374-
384+
375385 return μ_L_p_all, μ_L_q_all
376386
377387# Run all simulations at once
378388key, subkey = jr.split(key)
379- all_results = simulate_multiple_T(subkey, g_a, g_b, h_a, h_b, N_simu, T_values, N_samples=1000)
389+ all_results = simulate_multiple_T(subkey,
390+ g_a, g_b, h_a, h_b,
391+ N_simu, T_values,
392+ N_samples=1000)
380393
381394# Extract results
382395μ_L_p_all, μ_L_q_all = all_results
@@ -387,25 +400,36 @@ fig, axs = plt.subplots(2, 2, figsize=(14, 10))
387400for i, t in enumerate(T_values):
388401 row = i // 2
389402 col = i % 2
390-
403+
391404 # Get results for this T value
392405 μ_L_p = μ_L_p_all[i]
393406 μ_L_q = μ_L_q_all[i]
394-
395- μ_hat_p, μ_hat_q = jnp.nanmean(μ_L_p), jnp.nanmean(μ_L_q)
396- σ_hat_p, σ_hat_q = jnp.nanvar(μ_L_p), jnp.nanvar(μ_L_q)
407+
408+ μ_hat_p = jnp.nanmean(μ_L_p)
409+ μ_hat_q = jnp.nanmean(μ_L_q)
410+ σ_hat_p = jnp.nanvar(μ_L_p)
411+ σ_hat_q = jnp.nanvar(μ_L_q)
397412
398413 axs[row, col].set_xlabel('$μ_L$')
399414 axs[row, col].set_ylabel('frequency')
400415 axs[row, col].set_title(f'$T$={t}')
401- n_p, bins_p, _ = axs[row, col].hist(μ_L_p, bins=μ_range, color='r', alpha=0.5, label='$g$ generating')
402- n_q, bins_q, _ = axs[row, col].hist(μ_L_q, bins=μ_range, color='b', alpha=0.5, label='$h$ generating')
416+ n_p, bins_p, _ = axs[row, col].hist(
417+ μ_L_p, bins=μ_range,
418+ color='r', alpha=0.5, label='$g$ generating')
419+ n_q, bins_q, _ = axs[row, col].hist(
420+ μ_L_q, bins=μ_range,
421+ color='b', alpha=0.5, label='$h$ generating')
403422 axs[row, col].legend(loc=4)
404423
405- for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, σ_hat_p],
406- [n_q, bins_q, μ_hat_q, σ_hat_q]]:
424+ for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p,
425+ σ_hat_p],
426+ [n_q, bins_q, μ_hat_q,
427+ σ_hat_q]]:
407428 idx = jnp.argmax(n)
408- axs[row, col].text(bins[idx], n[idx], r'$\hat{μ}$='+f'{μ_hat:.4g}'+r', $\hat{σ}=$'+f'{σ_hat:.4g}')
429+ axs[row, col].text(
430+ bins[idx], n[idx],
431+ r'$\hat{μ}$=' + f'{μ_hat:.4g}' +
432+ r', $\hat{σ}=$' + f'{σ_hat:.4g}')
409433
410434plt.show()
411435```
434458
435459``` {code-cell} ipython3
436460key, subkey = jr.split(key)
437- μ_L_p, μ_L_q = simulate(subkey, g_a, g_b, params.F_a, params.F_b, N_simu)
461+ μ_L_p, μ_L_q = simulate(subkey, g_a, g_b, params.F_a,
462+ params.F_b, N_simu)
438463```
439464
440465``` {code-cell} ipython3
@@ -454,10 +479,14 @@ b_list = [0.5, 1.2, 5.]
454479``` {code-cell} ipython3
455480w_range = jnp.linspace(1e-5, 1-1e-5, 1000)
456481
457- plt.plot(w_range, g(w_range), label=f'g=Beta({g_a}, {g_b})')
458- plt.plot(w_range, beta_pdf(w_range, a_list[0], b_list[0]), label=f'$h_1$=Beta({a_list[0]},{b_list[0]})')
459- plt.plot(w_range, beta_pdf(w_range, a_list[1], b_list[1]), label=f'$h_2$=Beta({a_list[1]},{b_list[1]})')
460- plt.plot(w_range, beta_pdf(w_range, a_list[2], b_list[2]), label=f'$h_3$=Beta({a_list[2]},{b_list[2]})')
482+ plt.plot(w_range, g(w_range),
483+ label=f'g=Beta({g_a}, {g_b})')
484+ plt.plot(w_range, beta_pdf(w_range, a_list[0], b_list[0]),
485+ label=f'$h_1$=Beta({a_list[0]},{b_list[0]})')
486+ plt.plot(w_range, beta_pdf(w_range, a_list[1], b_list[1]),
487+ label=f'$h_2$=Beta({a_list[1]},{b_list[1]})')
488+ plt.plot(w_range, beta_pdf(w_range, a_list[2], b_list[2]),
489+ label=f'$h_3$=Beta({a_list[2]},{b_list[2]})')
461490plt.legend()
462491plt.ylim([0., 3.])
463492plt.show()
@@ -488,30 +517,44 @@ h_b = b_list[1]
488517
489518T_values_h2 = [1, 20]
490519key, subkey = jr.split(key)
491- all_results_h2 = simulate_multiple_T(subkey, g_a, g_b, h_a, h_b, N_simu, T_values_h2, N_samples=1000)
520+ all_results_h2 = simulate_multiple_T(subkey,
521+ g_a, g_b, h_a, h_b,
522+ N_simu, T_values_h2,
523+ N_samples=1000)
492524μ_L_p_all_h2, μ_L_q_all_h2 = all_results_h2
493525
494- fig, axs = plt.subplots(1,2, figsize=(14, 10))
526+ fig, axs = plt.subplots(1, 2, figsize=(14, 10))
495527μ_range = jnp.linspace(0, 2, 100)
496528
497529for i, t in enumerate(T_values_h2):
498530 μ_L_p = μ_L_p_all_h2[i]
499531 μ_L_q = μ_L_q_all_h2[i]
500-
501- μ_hat_p, μ_hat_q = jnp.nanmean(μ_L_p), jnp.nanmean(μ_L_q)
502- σ_hat_p, σ_hat_q = jnp.nanvar(μ_L_p), jnp.nanvar(μ_L_q)
532+
533+ μ_hat_p = jnp.nanmean(μ_L_p)
534+ μ_hat_q = jnp.nanmean(μ_L_q)
535+ σ_hat_p = jnp.nanvar(μ_L_p)
536+ σ_hat_q = jnp.nanvar(μ_L_q)
503537
504538 axs[i].set_xlabel('$μ_L$')
505539 axs[i].set_ylabel('frequency')
506540 axs[i].set_title(f'$T$={t}')
507- n_p, bins_p, _ = axs[i].hist(μ_L_p, bins=μ_range, color='r', alpha=0.5, label='$g$ generating')
508- n_q, bins_q, _ = axs[i].hist(μ_L_q, bins=μ_range, color='b', alpha=0.5, label='$h_2$ generating')
541+ n_p, bins_p, _ = axs[i].hist(
542+ μ_L_p, bins=μ_range,
543+ color='r', alpha=0.5, label='$g$ generating')
544+ n_q, bins_q, _ = axs[i].hist(
545+ μ_L_q, bins=μ_range,
546+ color='b', alpha=0.5, label='$h_2$ generating')
509547 axs[i].legend(loc=4)
510548
511- for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, σ_hat_p],
512- [n_q, bins_q, μ_hat_q, σ_hat_q]]:
549+ for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p,
550+ σ_hat_p],
551+ [n_q, bins_q, μ_hat_q,
552+ σ_hat_q]]:
513553 idx = jnp.argmax(n)
514- axs[i].text(bins[idx], n[idx], r'$\hat{μ}$='+f'{μ_hat:.4g}'+r', $\hat{σ}=$'+f'{σ_hat:.4g}')
554+ axs[i].text(
555+ bins[idx], n[idx],
556+ r'$\hat{μ}$=' + f'{μ_hat:.4g}' +
557+ r', $\hat{σ}=$' + f'{σ_hat:.4g}')
515558
516559plt.show()
517560```
@@ -526,27 +569,41 @@ h_b = b_list[2]
526569
527570T_list = [1, 20]
528571key, subkey = jr.split(key)
529- results = simulate_multiple_T(subkey, g_a, g_b, h_a, h_b, N_simu, T_list, N_samples=1000)
572+ results = simulate_multiple_T(subkey,
573+ g_a, g_b, h_a, h_b,
574+ N_simu, T_list,
575+ N_samples=1000)
530576
531577fig, axs = plt.subplots(1, 2, figsize=(14, 10))
532578μ_range = jnp.linspace(0, 2, 100)
533579
534580for i, t in enumerate(T_list):
535581 μ_L_p, μ_L_q = results[i]
536- μ_hat_p, μ_hat_q = jnp.nanmean(μ_L_p), jnp.nanmean(μ_L_q)
537- σ_hat_p, σ_hat_q = jnp.nanvar(μ_L_p), jnp.nanvar(μ_L_q)
582+ μ_hat_p = jnp.nanmean(μ_L_p)
583+ μ_hat_q = jnp.nanmean(μ_L_q)
584+ σ_hat_p = jnp.nanvar(μ_L_p)
585+ σ_hat_q = jnp.nanvar(μ_L_q)
538586
539587 axs[i].set_xlabel('$μ_L$')
540588 axs[i].set_ylabel('frequency')
541589 axs[i].set_title(f'$T$={t}')
542- n_p, bins_p, _ = axs[i].hist(μ_L_p, bins=μ_range, color='r', alpha=0.5, label='$g$ generating')
543- n_q, bins_q, _ = axs[i].hist(μ_L_q, bins=μ_range, color='b', alpha=0.5, label='$h_3$ generating')
590+ n_p, bins_p, _ = axs[i].hist(
591+ μ_L_p, bins=μ_range,
592+ color='r', alpha=0.5, label='$g$ generating')
593+ n_q, bins_q, _ = axs[i].hist(
594+ μ_L_q, bins=μ_range,
595+ color='b', alpha=0.5, label='$h_3$ generating')
544596 axs[i].legend(loc=4)
545597
546- for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, σ_hat_p],
547- [n_q, bins_q, μ_hat_q, σ_hat_q]]:
598+ for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p,
599+ σ_hat_p],
600+ [n_q, bins_q, μ_hat_q,
601+ σ_hat_q]]:
548602 idx = jnp.argmax(n)
549- axs[i].text(bins[idx], n[idx], r'$\hat{μ}$='+f'{μ_hat:.4g}'+r', $\hat{σ}=$'+f'{σ_hat:.4g}')
603+ axs[i].text(
604+ bins[idx], n[idx],
605+ r'$\hat{μ}$=' + f'{μ_hat:.4g}' +
606+ r', $\hat{σ}=$' + f'{σ_hat:.4g}')
550607
551608plt.show()
552609```
0 commit comments