@@ -751,8 +751,8 @@ Now let's simulate many agents simultaneously to examine the cross-sectional une
751751We first create a vectorized version of ` update_agent ` to efficiently update all agents in parallel:
752752
753753``` {code-cell} ipython3
754- # Create vectorized version of update_agent
755- # The last parameter is now w_bar (scalar) instead of σ (array)
754+ # Create vectorized version of update_agent.
755+ # Vectorize over key, status, wage_idx
756756update_agents_vmap = jax.vmap(
757757 update_agent, in_axes=(0, 0, 0, None, None)
758758)
@@ -761,76 +761,72 @@ update_agents_vmap = jax.vmap(
761761Next we define the core simulation function, which uses ` lax.fori_loop ` to efficiently iterate many agents forward in time:
762762
763763``` {code-cell} ipython3
764- @partial( jax.jit, static_argnums=(3, 4))
764+ @jax.jit
765765def _simulate_cross_section_compiled(
766766 key: jnp.ndarray,
767767 model: Model,
768768 w_bar: float,
769- n_agents: int,
769+ initial_wage_indices: jnp.ndarray,
770+ initial_status_vec: jnp.ndarray,
770771 T: int
771772 ):
772- """JIT-compiled core simulation loop using lax.fori_loop.
773- Returns only the final employment state to save memory."""
773+ """
774+ JIT-compiled core simulation loop for shifting the cross section
775+ using lax.fori_loop. Returns the final employment employment status
776+ cross-section.
777+
778+ """
774779 n, w_vals, P, P_cumsum, β, c, α, γ = model
780+ n_agents = len(initial_wage_indices)
775781
776- # Initialize arrays
777- wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
778- status = jnp.zeros(n_agents, dtype=jnp.int32)
779782
780783 def update(t, loop_state):
781- key, status, wage_indices = loop_state
782-
783- # Shift loop state forwards
784- key, subkey = jax.random.split(key)
785- agent_keys = jax.random.split(subkey, n_agents)
786-
784+ " Shift loop state forwards "
785+ status, wage_indices = loop_state
786+ step_key = jax.random.fold_in(key, t)
787+ agent_keys = jax.random.split(step_key, n_agents)
787788 status, wage_indices = update_agents_vmap(
788789 agent_keys, status, wage_indices, model, w_bar
789790 )
790-
791- return key, status, wage_indices
791+ return status, wage_indices
792792
793793 # Run simulation using fori_loop
794- initial_loop_state = (key, status, wage_indices )
794+ initial_loop_state = (initial_status_vec, initial_wage_indices )
795795 final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
796796
797797 # Return only final employment state
798- _, final_is_employed, _ = final_loop_state
798+ final_is_employed, _ = final_loop_state
799799 return final_is_employed
800800
801801
802802def simulate_cross_section(
803- model: Model,
804- n_agents: int = 100_000,
805- T: int = 200,
806- seed: int = 42
803+ model: Model, # Model instance with parameters
804+ n_agents: int = 100_000, # Number of agents to simulate
805+ T: int = 200, # Length of burn-in
806+ seed: int = 42 # For reproducibility
807807 ) -> float:
808808 """
809- Simulate employment paths for many agents and return final unemployment rate .
809+ Wrapper function for _simulate_cross_section_compiled .
810810
811- Parameters:
812- - model: Model instance with parameters
813- - n_agents: Number of agents to simulate
814- - T: Number of periods to simulate
815- - seed: Random seed for reproducibility
811+ Push forward a cross-section for T periods and return the final
812+ cross-sectional unemployment rate.
816813
817- Returns:
818- - unemployment_rate: Fraction of agents unemployed at time T
819814 """
820815 key = jax.random.PRNGKey(seed)
821816
822817 # Solve for optimal reservation wage
823818 v_u = vfi(model)
824819 w_bar = get_reservation_wage(v_u, model)
825820
826- # Run JIT-compiled simulation
821+ # Initialize arrays
822+ initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
823+ initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)
824+
827825 final_status = _simulate_cross_section_compiled(
828- key, model, w_bar, n_agents , T
826+ key, model, w_bar, initial_wage_indices, initial_status_vec , T
829827 )
830828
831- # Calculate unemployment rate at final period
832829 unemployment_rate = 1 - jnp.mean(final_status)
833-
834830 return unemployment_rate
835831```
836832
@@ -850,8 +846,13 @@ def plot_cross_sectional_unemployment(
850846 key = jax.random.PRNGKey(42)
851847 v_u = vfi(model)
852848 w_bar = get_reservation_wage(v_u, model)
849+
850+ # Initialize arrays
851+ initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
852+ initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)
853+
853854 final_status = _simulate_cross_section_compiled(
854- key, model, w_bar, n_agents , t_snapshot
855+ key, model, w_bar, initial_wage_indices, initial_status_vec , t_snapshot
855856 )
856857
857858 # Calculate unemployment rate
0 commit comments