Skip to content

Commit 75dc290

Browse files
committed
Update mccall_model.md
1 parent db5a05f commit 75dc290

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

lectures/mccall_model.md

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ and the pros and cons as they themselves see them." -- Robert E. Lucas, Jr.
3434

3535
In addition to what's in Anaconda, this lecture will need the following libraries:
3636

37-
```{code-cell} ipython
37+
```{code-cell} ipython3
3838
---
3939
tags: [hide-output]
4040
---
41-
!pip install quantecon
41+
!pip install quantecon jax
4242
```
4343

4444
## Overview
@@ -64,6 +64,7 @@ import matplotlib.pyplot as plt
6464
import numpy as np
6565
import jax
6666
import jax.numpy as jnp
67+
import jax.random as jr
6768
from typing import NamedTuple
6869
import quantecon as qe
6970
from quantecon.distributions import BetaBinomial
@@ -367,7 +368,6 @@ plt.show()
367368

368369
We are going to use JAX to accelerate our code.
369370

370-
* JAX provides automatic differentiation and JIT compilation capabilities.
371371
* We'll use NamedTuple for our model class to maintain immutability, which works well with JAX's functional programming paradigm.
372372

373373
Here's a class that stores the data and computes the values of state-action pairs,
@@ -455,7 +455,7 @@ def compute_reservation_wage(mcm, max_iter=500, tol=1e-6):
455455
# Simplify names
456456
c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q
457457
458-
# == First compute the value function == #
458+
# First compute the value function
459459
n = len(w)
460460
v = w / (1 - β) # initial guess
461461
@@ -474,7 +474,7 @@ def compute_reservation_wage(mcm, max_iter=500, tol=1e-6):
474474
initial_state = (v, 0, tol + 1)
475475
v_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state)
476476
477-
# == Now compute the reservation wage == #
477+
# Now compute the reservation wage
478478
return (1 - β) * (c + β * (v_final @ q))
479479
```
480480

@@ -606,7 +606,7 @@ def compute_reservation_wage_two(mcm, max_iter=500, tol=1e-5):
606606
# Simplify names
607607
c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q
608608
609-
# == First compute h == #
609+
# First compute h
610610
h = (w @ q) / (1 - β)
611611
612612
def body_fun(state):
@@ -623,7 +623,7 @@ def compute_reservation_wage_two(mcm, max_iter=500, tol=1e-5):
623623
initial_state = (h, 0, tol + 1)
624624
h_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state)
625625
626-
# == Now compute the reservation wage == #
626+
# Now compute the reservation wage
627627
return (1 - β) * h_final
628628
```
629629

@@ -660,8 +660,8 @@ cdf = jnp.cumsum(q_default)
660660
def compute_stopping_time(w_bar, key):
661661
def body_fun(state):
662662
t, key, done = state
663-
key, subkey = jax.random.split(key)
664-
u = jax.random.uniform(subkey)
663+
key, subkey = jr.split(key)
664+
u = jr.uniform(subkey)
665665
w = w_default[jnp.searchsorted(cdf, u)]
666666
done = w >= w_bar
667667
t = jnp.where(done, t, t + 1)
@@ -676,9 +676,9 @@ def compute_stopping_time(w_bar, key):
676676
return t_final
677677
678678
@jax.jit
679-
def compute_mean_stopping_time(w_bar, num_reps=100000, seed=0):
680-
key = jax.random.PRNGKey(seed)
681-
keys = jax.random.split(key, num_reps)
679+
def compute_mean_stopping_time(w_bar, num_reps=100000, seed=1234):
680+
key = jr.PRNGKey(seed)
681+
keys = jr.split(key, num_reps)
682682
obs = jax.vmap(compute_stopping_time, in_axes=(None, 0))(w_bar, keys)
683683
return jnp.mean(obs)
684684
@@ -777,9 +777,9 @@ class McCallModelContinuous(NamedTuple):
777777
μ: float # location parameter in lognormal distribution
778778
w_draws: jnp.ndarray # draws of wages for Monte Carlo
779779
780-
def create_mccall_continuous(c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=0):
781-
key = jax.random.PRNGKey(seed)
782-
s = jax.random.normal(key, (mc_size,))
780+
def create_mccall_continuous(c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=1234):
781+
key = jr.PRNGKey(seed)
782+
s = jr.normal(key, (mc_size,))
783783
w_draws = jnp.exp(μ + σ * s)
784784
return McCallModelContinuous(c=c, β=β, σ=σ, μ=μ, w_draws=w_draws)
785785
@@ -803,7 +803,7 @@ def compute_reservation_wage_continuous(mcmc, max_iter=500, tol=1e-5):
803803
initial_state = (h, 0, tol + 1)
804804
h_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state)
805805
806-
# == Now compute the reservation wage == #
806+
# Now compute the reservation wage
807807
return (1 - β) * h_final
808808
```
809809

0 commit comments

Comments
 (0)