@@ -240,7 +240,7 @@ Here's a function to compute an approximation to the fixed point of $Q$.
240240@jax.jit
241241def compute_fixed_point(model, tol=1e-4, max_iter=1000):
242242 """
243- Compute an approximation to the fixed point of Q using JAX while_loop .
243+ Compute an approximation to the fixed point of Q.
244244 """
245245
246246 def cond_fun(state):
@@ -303,7 +303,8 @@ for c in c_vals:
303303 model = create_job_search_model(c=c)
304304 f_star = compute_fixed_point(model)
305305 res_wage_function = jnp.exp(f_star * (1 - model.β))
306- ax.plot(model.z_grid, res_wage_function, label=rf"$\bar w$ at $c = {c}$")
306+ ax.plot(model.z_grid, res_wage_function,
307+ label=rf"$\bar w$ at $c = {c}$")
307308
308309ax.set(xlabel="$z$", ylabel="wage")
309310ax.legend()
@@ -323,7 +324,7 @@ For simplicity we’ll fix the initial state at $z_t = 0$.
323324def compute_unemployment_duration(model,
324325 key=jr.PRNGKey(1234), num_reps=100_000):
325326 """
326- Compute expected unemployment duration using JAX .
327+ Compute expected unemployment duration.
327328 """
328329 f_star = compute_fixed_point(model)
329330 μ, s, d = model.μ, model.s, model.d
0 commit comments