Skip to content

Commit 58d1cd2

Browse files
committed
update pep8
1 parent e57deef commit 58d1cd2

File tree

1 file changed

+113
-56
lines changed

1 file changed

+113
-56
lines changed

lectures/imp_sample.md

Lines changed: 113 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,18 @@ We first take a look at the density functions `f` and `g` .
8181
class ImpSampleParams(NamedTuple):
8282
F_a: float = 1.0 # Beta parameters for f
8383
F_b: float = 1.0
84-
G_a: float = 3.0 # Beta parameters for g
84+
G_a: float = 3.0 # Beta parameters for g
8585
G_b: float = 1.2
8686
8787
params = ImpSampleParams()
8888
8989
@jax.jit
9090
def beta_pdf(w, a, b):
9191
"""Beta probability density function."""
92-
log_beta_const = gammaln(a) + gammaln(b) - gammaln(a + b)
93-
log_pdf = (a - 1) * jnp.log(w) + (b - 1) * jnp.log(1 - w) - log_beta_const
92+
log_beta_const = (gammaln(a) + gammaln(b) -
93+
gammaln(a + b))
94+
log_pdf = ((a - 1) * jnp.log(w) + (b - 1) *
95+
jnp.log(1 - w) - log_beta_const)
9496
return jnp.exp(log_pdf)
9597
9698
@jax.jit
@@ -194,8 +196,10 @@ mystnb:
194196
---
195197
w_range = jnp.linspace(1e-5, 1-1e-5, 1000)
196198
197-
plt.plot(w_range, g(w_range), label=f'g=Beta({g_a}, {g_b})')
198-
plt.plot(w_range, beta_pdf(w_range, 0.5, 0.5), label=f'h=Beta({h_a}, {h_b})')
199+
plt.plot(w_range, g(w_range),
200+
label=f'g=Beta({g_a}, {g_b})')
201+
plt.plot(w_range, beta_pdf(w_range, 0.5, 0.5),
202+
label=f'h=Beta({h_a}, {h_b})')
199203
plt.legend()
200204
plt.ylim([0., 3.])
201205
plt.show()
@@ -230,18 +234,18 @@ def estimate_single_path(key, p_a, p_b, q_a, q_b, T):
230234
L, weight, key_state = carry
231235
key_state, subkey = jr.split(key_state)
232236
w = jr.beta(subkey, q_a, q_b)
233-
237+
234238
# Compute likelihood ratio using f/g functions
235239
likelihood_ratio = f(w) / g(w)
236240
L = L * likelihood_ratio
237-
241+
238242
# Importance sampling weight with beta_pdf
239243
p_w = beta_pdf(w, p_a, p_b)
240244
q_w = beta_pdf(w, q_a, q_b)
241245
weight = weight * (p_w / q_w)
242-
246+
243247
return (L, weight, key_state)
244-
248+
245249
# Use fori_loop for dynamic T values
246250
final_L, final_weight, _ = jax.lax.fori_loop(
247251
0, T, loop_body, (1.0, 1.0, key)
@@ -252,13 +256,13 @@ def estimate_single_path(key, p_a, p_b, q_a, q_b, T):
252256
def estimate(key, p_a, p_b, q_a, q_b, T=1, N=10000):
253257
"""Estimation of a batch of sample paths."""
254258
keys = jr.split(key, N)
255-
259+
256260
# Use vmap for vectorized computation
257261
estimates = jax.vmap(
258-
estimate_single_path,
262+
estimate_single_path,
259263
in_axes=(0, *[None]*5)
260264
)(keys, p_a, p_b, q_a, q_b, T)
261-
265+
262266
return jnp.mean(estimates)
263267
```
264268

@@ -313,21 +317,24 @@ The code below produces distributions of estimates using both Monte Carlo and i
313317

314318
```{code-cell} ipython3
315319
@partial(jax.jit, static_argnames=['N_simu', 'N_samples'])
316-
def simulate(key, p_a, p_b, q_a, q_b, N_simu, T=1, N_samples=1000):
320+
def simulate(key, p_a, p_b, q_a, q_b, N_simu, T=1,
321+
N_samples=1000):
317322
"""Simulation for both Monte Carlo and importance sampling."""
318323
keys = jr.split(key, 2 * N_simu)
319324
keys_p = keys[:N_simu]
320325
keys_q = keys[N_simu:]
321-
326+
322327
def run_monte_carlo(key_batch):
323-
return estimate(key_batch, p_a, p_b, p_a, p_b, T, N_samples)
324-
328+
return estimate(key_batch, p_a, p_b, p_a, p_b, T,
329+
N_samples)
330+
325331
def run_importance_sampling(key_batch):
326-
return estimate(key_batch, p_a, p_b, q_a, q_b, T, N_samples)
327-
332+
return estimate(key_batch, p_a, p_b, q_a, q_b, T,
333+
N_samples)
334+
328335
μ_L_p = jax.vmap(run_monte_carlo)(keys_p)
329336
μ_L_q = jax.vmap(run_importance_sampling)(keys_q)
330-
337+
331338
return μ_L_p, μ_L_q
332339
```
333340

@@ -358,25 +365,31 @@ Next, we present distributions of estimates for $\hat{E} \left[L\left(\omega^t\r
358365
```{code-cell} ipython3
359366
T_values = [1, 5, 10, 20]
360367
361-
def simulate_multiple_T(key, p_a, p_b, q_a, q_b, N_simu, T_list, N_samples=1000):
368+
def simulate_multiple_T(key, p_a, p_b, q_a, q_b, N_simu,
369+
T_list, N_samples=1000):
362370
"""Simulation for multiple T values."""
363371
n_T = len(T_list)
364372
keys = jr.split(key, n_T)
365-
373+
366374
results = []
367375
for i, T in enumerate(T_list):
368-
result = simulate(keys[i], p_a, p_b, q_a, q_b, N_simu, T, N_samples)
376+
result = simulate(keys[i],
377+
p_a, p_b, q_a, q_b, N_simu, T,
378+
N_samples)
369379
results.append(result)
370380
371381
# Stack results into arrays for consistency
372382
μ_L_p_all = jnp.stack([r[0] for r in results])
373383
μ_L_q_all = jnp.stack([r[1] for r in results])
374-
384+
375385
return μ_L_p_all, μ_L_q_all
376386
377387
# Run all simulations at once
378388
key, subkey = jr.split(key)
379-
all_results = simulate_multiple_T(subkey, g_a, g_b, h_a, h_b, N_simu, T_values, N_samples=1000)
389+
all_results = simulate_multiple_T(subkey,
390+
g_a, g_b, h_a, h_b,
391+
N_simu, T_values,
392+
N_samples=1000)
380393
381394
# Extract results
382395
μ_L_p_all, μ_L_q_all = all_results
@@ -387,25 +400,36 @@ fig, axs = plt.subplots(2, 2, figsize=(14, 10))
387400
for i, t in enumerate(T_values):
388401
row = i // 2
389402
col = i % 2
390-
403+
391404
# Get results for this T value
392405
μ_L_p = μ_L_p_all[i]
393406
μ_L_q = μ_L_q_all[i]
394-
395-
μ_hat_p, μ_hat_q = jnp.nanmean(μ_L_p), jnp.nanmean(μ_L_q)
396-
σ_hat_p, σ_hat_q = jnp.nanvar(μ_L_p), jnp.nanvar(μ_L_q)
407+
408+
μ_hat_p = jnp.nanmean(μ_L_p)
409+
μ_hat_q = jnp.nanmean(μ_L_q)
410+
σ_hat_p = jnp.nanvar(μ_L_p)
411+
σ_hat_q = jnp.nanvar(μ_L_q)
397412
398413
axs[row, col].set_xlabel('$μ_L$')
399414
axs[row, col].set_ylabel('frequency')
400415
axs[row, col].set_title(f'$T$={t}')
401-
n_p, bins_p, _ = axs[row, col].hist(μ_L_p, bins=μ_range, color='r', alpha=0.5, label='$g$ generating')
402-
n_q, bins_q, _ = axs[row, col].hist(μ_L_q, bins=μ_range, color='b', alpha=0.5, label='$h$ generating')
416+
n_p, bins_p, _ = axs[row, col].hist(
417+
μ_L_p, bins=μ_range,
418+
color='r', alpha=0.5, label='$g$ generating')
419+
n_q, bins_q, _ = axs[row, col].hist(
420+
μ_L_q, bins=μ_range,
421+
color='b', alpha=0.5, label='$h$ generating')
403422
axs[row, col].legend(loc=4)
404423
405-
for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, σ_hat_p],
406-
[n_q, bins_q, μ_hat_q, σ_hat_q]]:
424+
for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p,
425+
σ_hat_p],
426+
[n_q, bins_q, μ_hat_q,
427+
σ_hat_q]]:
407428
idx = jnp.argmax(n)
408-
axs[row, col].text(bins[idx], n[idx], r'$\hat{μ}$='+f'{μ_hat:.4g}'+r', $\hat{σ}=$'+f'{σ_hat:.4g}')
429+
axs[row, col].text(
430+
bins[idx], n[idx],
431+
r'$\hat{μ}$=' + f'{μ_hat:.4g}' +
432+
r', $\hat{σ}=$' + f'{σ_hat:.4g}')
409433
410434
plt.show()
411435
```
@@ -434,7 +458,8 @@ $$
434458

435459
```{code-cell} ipython3
436460
key, subkey = jr.split(key)
437-
μ_L_p, μ_L_q = simulate(subkey, g_a, g_b, params.F_a, params.F_b, N_simu)
461+
μ_L_p, μ_L_q = simulate(subkey, g_a, g_b, params.F_a,
462+
params.F_b, N_simu)
438463
```
439464

440465
```{code-cell} ipython3
@@ -454,10 +479,14 @@ b_list = [0.5, 1.2, 5.]
454479
```{code-cell} ipython3
455480
w_range = jnp.linspace(1e-5, 1-1e-5, 1000)
456481
457-
plt.plot(w_range, g(w_range), label=f'g=Beta({g_a}, {g_b})')
458-
plt.plot(w_range, beta_pdf(w_range, a_list[0], b_list[0]), label=f'$h_1$=Beta({a_list[0]},{b_list[0]})')
459-
plt.plot(w_range, beta_pdf(w_range, a_list[1], b_list[1]), label=f'$h_2$=Beta({a_list[1]},{b_list[1]})')
460-
plt.plot(w_range, beta_pdf(w_range, a_list[2], b_list[2]), label=f'$h_3$=Beta({a_list[2]},{b_list[2]})')
482+
plt.plot(w_range, g(w_range),
483+
label=f'g=Beta({g_a}, {g_b})')
484+
plt.plot(w_range, beta_pdf(w_range, a_list[0], b_list[0]),
485+
label=f'$h_1$=Beta({a_list[0]},{b_list[0]})')
486+
plt.plot(w_range, beta_pdf(w_range, a_list[1], b_list[1]),
487+
label=f'$h_2$=Beta({a_list[1]},{b_list[1]})')
488+
plt.plot(w_range, beta_pdf(w_range, a_list[2], b_list[2]),
489+
label=f'$h_3$=Beta({a_list[2]},{b_list[2]})')
461490
plt.legend()
462491
plt.ylim([0., 3.])
463492
plt.show()
@@ -488,30 +517,44 @@ h_b = b_list[1]
488517
489518
T_values_h2 = [1, 20]
490519
key, subkey = jr.split(key)
491-
all_results_h2 = simulate_multiple_T(subkey, g_a, g_b, h_a, h_b, N_simu, T_values_h2, N_samples=1000)
520+
all_results_h2 = simulate_multiple_T(subkey,
521+
g_a, g_b, h_a, h_b,
522+
N_simu, T_values_h2,
523+
N_samples=1000)
492524
μ_L_p_all_h2, μ_L_q_all_h2 = all_results_h2
493525
494-
fig, axs = plt.subplots(1,2, figsize=(14, 10))
526+
fig, axs = plt.subplots(1, 2, figsize=(14, 10))
495527
μ_range = jnp.linspace(0, 2, 100)
496528
497529
for i, t in enumerate(T_values_h2):
498530
μ_L_p = μ_L_p_all_h2[i]
499531
μ_L_q = μ_L_q_all_h2[i]
500-
501-
μ_hat_p, μ_hat_q = jnp.nanmean(μ_L_p), jnp.nanmean(μ_L_q)
502-
σ_hat_p, σ_hat_q = jnp.nanvar(μ_L_p), jnp.nanvar(μ_L_q)
532+
533+
μ_hat_p = jnp.nanmean(μ_L_p)
534+
μ_hat_q = jnp.nanmean(μ_L_q)
535+
σ_hat_p = jnp.nanvar(μ_L_p)
536+
σ_hat_q = jnp.nanvar(μ_L_q)
503537
504538
axs[i].set_xlabel('$μ_L$')
505539
axs[i].set_ylabel('frequency')
506540
axs[i].set_title(f'$T$={t}')
507-
n_p, bins_p, _ = axs[i].hist(μ_L_p, bins=μ_range, color='r', alpha=0.5, label='$g$ generating')
508-
n_q, bins_q, _ = axs[i].hist(μ_L_q, bins=μ_range, color='b', alpha=0.5, label='$h_2$ generating')
541+
n_p, bins_p, _ = axs[i].hist(
542+
μ_L_p, bins=μ_range,
543+
color='r', alpha=0.5, label='$g$ generating')
544+
n_q, bins_q, _ = axs[i].hist(
545+
μ_L_q, bins=μ_range,
546+
color='b', alpha=0.5, label='$h_2$ generating')
509547
axs[i].legend(loc=4)
510548
511-
for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, σ_hat_p],
512-
[n_q, bins_q, μ_hat_q, σ_hat_q]]:
549+
for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p,
550+
σ_hat_p],
551+
[n_q, bins_q, μ_hat_q,
552+
σ_hat_q]]:
513553
idx = jnp.argmax(n)
514-
axs[i].text(bins[idx], n[idx], r'$\hat{μ}$='+f'{μ_hat:.4g}'+r', $\hat{σ}=$'+f'{σ_hat:.4g}')
554+
axs[i].text(
555+
bins[idx], n[idx],
556+
r'$\hat{μ}$=' + f'{μ_hat:.4g}' +
557+
r', $\hat{σ}=$' + f'{σ_hat:.4g}')
515558
516559
plt.show()
517560
```
@@ -526,27 +569,41 @@ h_b = b_list[2]
526569
527570
T_list = [1, 20]
528571
key, subkey = jr.split(key)
529-
results = simulate_multiple_T(subkey, g_a, g_b, h_a, h_b, N_simu, T_list, N_samples=1000)
572+
results = simulate_multiple_T(subkey,
573+
g_a, g_b, h_a, h_b,
574+
N_simu, T_list,
575+
N_samples=1000)
530576
531577
fig, axs = plt.subplots(1, 2, figsize=(14, 10))
532578
μ_range = jnp.linspace(0, 2, 100)
533579
534580
for i, t in enumerate(T_list):
535581
μ_L_p, μ_L_q = results[i]
536-
μ_hat_p, μ_hat_q = jnp.nanmean(μ_L_p), jnp.nanmean(μ_L_q)
537-
σ_hat_p, σ_hat_q = jnp.nanvar(μ_L_p), jnp.nanvar(μ_L_q)
582+
μ_hat_p = jnp.nanmean(μ_L_p)
583+
μ_hat_q = jnp.nanmean(μ_L_q)
584+
σ_hat_p = jnp.nanvar(μ_L_p)
585+
σ_hat_q = jnp.nanvar(μ_L_q)
538586
539587
axs[i].set_xlabel('$μ_L$')
540588
axs[i].set_ylabel('frequency')
541589
axs[i].set_title(f'$T$={t}')
542-
n_p, bins_p, _ = axs[i].hist(μ_L_p, bins=μ_range, color='r', alpha=0.5, label='$g$ generating')
543-
n_q, bins_q, _ = axs[i].hist(μ_L_q, bins=μ_range, color='b', alpha=0.5, label='$h_3$ generating')
590+
n_p, bins_p, _ = axs[i].hist(
591+
μ_L_p, bins=μ_range,
592+
color='r', alpha=0.5, label='$g$ generating')
593+
n_q, bins_q, _ = axs[i].hist(
594+
μ_L_q, bins=μ_range,
595+
color='b', alpha=0.5, label='$h_3$ generating')
544596
axs[i].legend(loc=4)
545597
546-
for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, σ_hat_p],
547-
[n_q, bins_q, μ_hat_q, σ_hat_q]]:
598+
for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p,
599+
σ_hat_p],
600+
[n_q, bins_q, μ_hat_q,
601+
σ_hat_q]]:
548602
idx = jnp.argmax(n)
549-
axs[i].text(bins[idx], n[idx], r'$\hat{μ}$='+f'{μ_hat:.4g}'+r', $\hat{σ}=$'+f'{σ_hat:.4g}')
603+
axs[i].text(
604+
bins[idx], n[idx],
605+
r'$\hat{μ}$=' + f'{μ_hat:.4g}' +
606+
r', $\hat{σ}=$' + f'{σ_hat:.4g}')
550607
551608
plt.show()
552609
```

0 commit comments

Comments
 (0)