Add JAX implementation to ifp_advanced lecture #705
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds a comprehensive JAX implementation to the Income Fluctuation Problem II lecture (
ifp_advanced.md), providing an alternative high-performance implementation alongside the existing Numba version.Changes Made
1. Section Restructuring
2. JAX Implementation Details
Core Components:
IFP_JAX: Implemented as aNamedTuple(instead of a regular class) to ensure compatibility with JAX's JIT compilation and hashability requirementsu_prime(c, γ): Marginal utilityu_prime_inv(c, γ): Inverse marginal utilityR(z, ζ, a_r, b_r): Gross return on assetsY(z, η, a_y, b_y): Labor incomecreate_ifp_jax(): Factory function to construct IFP_JAX instances with parameter validationK_jax(): Coleman-Reffett operator using JAX's JIT compilation and vectorizationsolve_model_time_iter_jax(): Time iteration solver for JAX3. Numerical Validation
Added comprehensive comparison section showing:
4. Technical Improvements
jax.config.update("jax_enable_x64", True))jax.jitasjax_jitto avoid overwritingnumba.jitTest Results
Both implementations successfully converge:
Numerical comparison:
The solutions are essentially identical, with minor differences arising from:
Benefits
The JAX implementation offers:
Validation
✅ Script runs successfully with exit code 0
✅ All plots generate correctly
✅ Both implementations produce consistent results
✅ No breaking changes to existing Numba implementation
🤖 Generated with Claude Code