Skip to content

Commit 72ae11f

Browse files
jstacclaude
andcommitted
Improve key handling and fix parameter consistency in McCall lectures
- Refactor random key handling to use fold_in instead of key threading - More idiomatic JAX pattern for indexed loops - Removes key from loop state for cleaner code - Deterministic randomness based on time step - Fix missing n_agents variable in _simulate_cross_section_compiled - Extract from initial_wage_indices using len() - Standardize separation rate across lectures - Set α = 0.05 in mccall_fitted_vfi to match mccall_model_with_sep_markov - All economic parameters now consistent between lectures 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 977a93e commit 72ae11f

File tree

2 files changed

+45
-44
lines changed

2 files changed

+45
-44
lines changed

lectures/mccall_fitted_vfi.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ class Model(NamedTuple):
268268
269269
def create_mccall_model(
270270
c: float = 1.0,
271-
α: float = 0.1,
271+
α: float = 0.05,
272272
β: float = 0.96,
273273
ρ: float = 0.9,
274274
ν: float = 0.2,
@@ -633,29 +633,29 @@ def _simulate_cross_section_compiled(
633633
c, α, β, ρ, ν, γ, w_grid, z_draws = model
634634
635635
# Initialize arrays
636-
key, subkey = jax.random.split(key)
636+
init_key, subkey = jax.random.split(key)
637637
wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν)
638638
status = jnp.zeros(n_agents, dtype=jnp.int32)
639639
640640
def update(t, loop_state):
641-
key, status, wages = loop_state
641+
status, wages = loop_state
642642
643643
# Shift loop state forwards
644-
key, subkey = jax.random.split(key)
645-
agent_keys = jax.random.split(subkey, n_agents)
644+
step_key = jax.random.fold_in(init_key, t)
645+
agent_keys = jax.random.split(step_key, n_agents)
646646
647647
status, wages = update_agents_vmap(
648648
agent_keys, status, wages, model, w_bar
649649
)
650650
651-
return key, status, wages
651+
return status, wages
652652
653653
# Run simulation using fori_loop
654-
initial_loop_state = (key, status, wages)
654+
initial_loop_state = (status, wages)
655655
final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
656656
657657
# Return only final employment state
658-
_, final_is_employed, _ = final_loop_state
658+
final_is_employed, _ = final_loop_state
659659
return final_is_employed
660660
661661

lectures/mccall_model_with_sep_markov.md

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -751,8 +751,8 @@ Now let's simulate many agents simultaneously to examine the cross-sectional une
751751
We first create a vectorized version of `update_agent` to efficiently update all agents in parallel:
752752

753753
```{code-cell} ipython3
754-
# Create vectorized version of update_agent
755-
# The last parameter is now w_bar (scalar) instead of σ (array)
754+
# Create vectorized version of update_agent.
755+
# Vectorize over key, status, wage_idx
756756
update_agents_vmap = jax.vmap(
757757
update_agent, in_axes=(0, 0, 0, None, None)
758758
)
@@ -761,76 +761,72 @@ update_agents_vmap = jax.vmap(
761761
Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:
762762

763763
```{code-cell} ipython3
764-
@partial(jax.jit, static_argnums=(3, 4))
764+
@jax.jit
765765
def _simulate_cross_section_compiled(
766766
key: jnp.ndarray,
767767
model: Model,
768768
w_bar: float,
769-
n_agents: int,
769+
initial_wage_indices: jnp.ndarray,
770+
initial_status_vec: jnp.ndarray,
770771
T: int
771772
):
772-
"""JIT-compiled core simulation loop using lax.fori_loop.
773-
Returns only the final employment state to save memory."""
773+
"""
774+
JIT-compiled core simulation loop for shifting the cross section
775+
using lax.fori_loop. Returns the final employment employment status
776+
cross-section.
777+
778+
"""
774779
n, w_vals, P, P_cumsum, β, c, α, γ = model
780+
n_agents = len(initial_wage_indices)
775781
776-
# Initialize arrays
777-
wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
778-
status = jnp.zeros(n_agents, dtype=jnp.int32)
779782
780783
def update(t, loop_state):
781-
key, status, wage_indices = loop_state
782-
783-
# Shift loop state forwards
784-
key, subkey = jax.random.split(key)
785-
agent_keys = jax.random.split(subkey, n_agents)
786-
784+
" Shift loop state forwards "
785+
status, wage_indices = loop_state
786+
step_key = jax.random.fold_in(key, t)
787+
agent_keys = jax.random.split(step_key, n_agents)
787788
status, wage_indices = update_agents_vmap(
788789
agent_keys, status, wage_indices, model, w_bar
789790
)
790-
791-
return key, status, wage_indices
791+
return status, wage_indices
792792
793793
# Run simulation using fori_loop
794-
initial_loop_state = (key, status, wage_indices)
794+
initial_loop_state = (initial_status_vec, initial_wage_indices)
795795
final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
796796
797797
# Return only final employment state
798-
_, final_is_employed, _ = final_loop_state
798+
final_is_employed, _ = final_loop_state
799799
return final_is_employed
800800
801801
802802
def simulate_cross_section(
803-
model: Model,
804-
n_agents: int = 100_000,
805-
T: int = 200,
806-
seed: int = 42
803+
model: Model, # Model instance with parameters
804+
n_agents: int = 100_000, # Number of agents to simulate
805+
T: int = 200, # Length of burn-in
806+
seed: int = 42 # For reproducibility
807807
) -> float:
808808
"""
809-
Simulate employment paths for many agents and return final unemployment rate.
809+
Wrapper function for _simulate_cross_section_compiled.
810810
811-
Parameters:
812-
- model: Model instance with parameters
813-
- n_agents: Number of agents to simulate
814-
- T: Number of periods to simulate
815-
- seed: Random seed for reproducibility
811+
Push forward a cross-section for T periods and return the final
812+
cross-sectional unemployment rate.
816813
817-
Returns:
818-
- unemployment_rate: Fraction of agents unemployed at time T
819814
"""
820815
key = jax.random.PRNGKey(seed)
821816
822817
# Solve for optimal reservation wage
823818
v_u = vfi(model)
824819
w_bar = get_reservation_wage(v_u, model)
825820
826-
# Run JIT-compiled simulation
821+
# Initialize arrays
822+
initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
823+
initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)
824+
827825
final_status = _simulate_cross_section_compiled(
828-
key, model, w_bar, n_agents, T
826+
key, model, w_bar, initial_wage_indices, initial_status_vec, T
829827
)
830828
831-
# Calculate unemployment rate at final period
832829
unemployment_rate = 1 - jnp.mean(final_status)
833-
834830
return unemployment_rate
835831
```
836832

@@ -850,8 +846,13 @@ def plot_cross_sectional_unemployment(
850846
key = jax.random.PRNGKey(42)
851847
v_u = vfi(model)
852848
w_bar = get_reservation_wage(v_u, model)
849+
850+
# Initialize arrays
851+
initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
852+
initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)
853+
853854
final_status = _simulate_cross_section_compiled(
854-
key, model, w_bar, n_agents, t_snapshot
855+
key, model, w_bar, initial_wage_indices, initial_status_vec, t_snapshot
855856
)
856857
857858
# Calculate unemployment rate

0 commit comments

Comments
 (0)