@@ -34,11 +34,11 @@ and the pros and cons as they themselves see them." -- Robert E. Lucas, Jr.
3434
3535In addition to what's in Anaconda, this lecture will need the following libraries:
3636
37- ``` {code-cell} ipython
37+ ``` {code-cell} ipython3
3838---
3939tags: [hide-output]
4040---
41- !pip install quantecon
41+ !pip install quantecon jax
4242```
4343
4444## Overview
@@ -64,6 +64,7 @@ import matplotlib.pyplot as plt
6464import numpy as np
6565import jax
6666import jax.numpy as jnp
67+ import jax.random as jr
6768from typing import NamedTuple
6869import quantecon as qe
6970from quantecon.distributions import BetaBinomial
@@ -367,7 +368,6 @@ plt.show()
367368
368369We are going to use JAX to accelerate our code.
369370
370- * JAX provides automatic differentiation and JIT compilation capabilities.
371371* We'll use NamedTuple for our model class to maintain immutability, which works well with JAX's functional programming paradigm.
372372
373373Here's a class that stores the data and computes the values of state-action pairs,
@@ -455,7 +455,7 @@ def compute_reservation_wage(mcm, max_iter=500, tol=1e-6):
455455 # Simplify names
456456 c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q
457457
458- # == First compute the value function == #
458+ # First compute the value function
459459 n = len(w)
460460 v = w / (1 - β) # initial guess
461461
@@ -474,7 +474,7 @@ def compute_reservation_wage(mcm, max_iter=500, tol=1e-6):
474474 initial_state = (v, 0, tol + 1)
475475 v_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state)
476476
477- # == Now compute the reservation wage == #
477+ # Now compute the reservation wage
478478 return (1 - β) * (c + β * (v_final @ q))
479479```
480480
@@ -606,7 +606,7 @@ def compute_reservation_wage_two(mcm, max_iter=500, tol=1e-5):
606606 # Simplify names
607607 c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q
608608
609- # == First compute h == #
609+ # First compute h
610610 h = (w @ q) / (1 - β)
611611
612612 def body_fun(state):
@@ -623,7 +623,7 @@ def compute_reservation_wage_two(mcm, max_iter=500, tol=1e-5):
623623 initial_state = (h, 0, tol + 1)
624624 h_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state)
625625
626- # == Now compute the reservation wage == #
626+ # Now compute the reservation wage
627627 return (1 - β) * h_final
628628```
629629
@@ -660,8 +660,8 @@ cdf = jnp.cumsum(q_default)
660660def compute_stopping_time(w_bar, key):
661661 def body_fun(state):
662662 t, key, done = state
663- key, subkey = jax.random .split(key)
664- u = jax.random .uniform(subkey)
663+ key, subkey = jr .split(key)
664+ u = jr .uniform(subkey)
665665 w = w_default[jnp.searchsorted(cdf, u)]
666666 done = w >= w_bar
667667 t = jnp.where(done, t, t + 1)
@@ -676,9 +676,9 @@ def compute_stopping_time(w_bar, key):
676676 return t_final
677677
678678@jax.jit
679- def compute_mean_stopping_time(w_bar, num_reps=100000, seed=0 ):
680- key = jax.random .PRNGKey(seed)
681- keys = jax.random .split(key, num_reps)
679+ def compute_mean_stopping_time(w_bar, num_reps=100000, seed=1234 ):
680+ key = jr .PRNGKey(seed)
681+ keys = jr .split(key, num_reps)
682682 obs = jax.vmap(compute_stopping_time, in_axes=(None, 0))(w_bar, keys)
683683 return jnp.mean(obs)
684684
@@ -777,9 +777,9 @@ class McCallModelContinuous(NamedTuple):
777777 μ: float # location parameter in lognormal distribution
778778 w_draws: jnp.ndarray # draws of wages for Monte Carlo
779779
780- def create_mccall_continuous(c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=0 ):
781- key = jax.random .PRNGKey(seed)
782- s = jax.random .normal(key, (mc_size,))
780+ def create_mccall_continuous(c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=1234 ):
781+ key = jr .PRNGKey(seed)
782+ s = jr .normal(key, (mc_size,))
783783 w_draws = jnp.exp(μ + σ * s)
784784 return McCallModelContinuous(c=c, β=β, σ=σ, μ=μ, w_draws=w_draws)
785785
@@ -803,7 +803,7 @@ def compute_reservation_wage_continuous(mcmc, max_iter=500, tol=1e-5):
803803 initial_state = (h, 0, tol + 1)
804804 h_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state)
805805
806- # == Now compute the reservation wage == #
806+ # Now compute the reservation wage
807807 return (1 - β) * h_final
808808```
809809
0 commit comments