@@ -88,6 +88,7 @@ import quantecon as qe
8888import jax
8989import jax.numpy as jnp
9090from jax.numpy.linalg import eigvals, solve
91+ from jax.experimental import checkify
9192from typing import NamedTuple
9293```
9394
@@ -567,6 +568,20 @@ We will define a function tree_price to compute $v$ given parameters stored in
567568the class AssetPriceModel
568569
569570```{code-cell} ipython3
571+ class MarkovChain(NamedTuple):
572+ """
573+ A class that stores the primitives of a Markov chain.
574+ Parameters
575+ ----------
576+ P : jnp.ndarray
577+ Transition matrix
578+ state_values : jnp.ndarray
579+ The values associated with each state
580+ """
581+ P: jnp.ndarray
582+ state_values: jnp.ndarray
583+
584+
570585class AssetPriceModel(NamedTuple):
571586 """
572587 A class that stores the primitives of the asset pricing model.
@@ -584,34 +599,42 @@ class AssetPriceModel(NamedTuple):
584599 n: int
585600 The number of states
586601 """
587- mc: qe. MarkovChain
602+ mc: MarkovChain
588603 g: callable
589604 β: float
590605 γ: float
591606 n: int
592607
593608
594- def create_ap_model(mc=None, g=jnp.exp, β=0.96, γ=2.0):
595- """Create an AssetPriceModel class"""
596- if mc is None:
597- n, ρ, σ = 25, 0.9, 0.02
598- mc = qe.tauchen(n, ρ, σ)
599- else:
600- mc = mc
601- n = mc.P.shape[0]
609+ def create_ap_model(g=jnp.exp, β=0.96, γ=2.0):
610+ """Create an AssetPriceModel class using standard Markov chain."""
611+ n, ρ, σ = 25, 0.9, 0.02
612+ qe_mc = qe.tauchen(n, ρ, σ)
613+ P = jnp.array(qe_mc.P)
614+ state_values = jnp.array(qe_mc.state_values)
615+ mc = MarkovChain(P=P, state_values=state_values)
602616
603617 return AssetPriceModel(mc=mc, g=g, β=β, γ=γ, n=n)
604618
605619
620+ def create_customized_ap_model(mc: MarkovChain, g=jnp.exp, β=0.96, γ=2.0):
621+ """Create an AssetPriceModel class using a customized Markov chain."""
622+ n = mc.P.shape[0]
623+ return AssetPriceModel(mc=mc, g=g, β=β, γ=γ, n=n)
624+
625+
606626def test_stability(Q, β):
607- """
608- Stability test for a given matrix Q.
609- """
610- sr = np.max(np.abs(eigvals(Q)))
611- if not sr < 1 / β:
612- msg = f"Spectral radius condition failed with radius = {sr}"
613- raise ValueError(msg)
627+ """Stability test for a given matrix Q."""
628+ sr = jnp.max(jnp.abs(eigvals(Q)))
629+ checkify.check(
630+ sr < 1 / β,
631+ "Spectral radius condition failed with radius = {sr}", sr=sr
632+ )
633+ return sr
634+
614635
636+ # Wrap the check function to be JIT-safe
637+ test_stability = checkify.checkify(test_stability, errors=checkify.user_checks)
615638
616639def tree_price(ap):
617640 """
@@ -633,7 +656,8 @@ def tree_price(ap):
633656 J = P * ap.g(y)**(1 - γ)
634657
635658 # Make sure that a unique solution exists
636- test_stability(J, β)
659+ err, out = test_stability(J, β)
660+ err.throw()
637661
638662 # Compute v
639663 I = jnp.identity(ap.n)
@@ -661,7 +685,7 @@ states = ap.mc.state_values
661685fig, ax = plt.subplots()
662686
663687for γ in γs:
664- tem_ap = create_ap_model (mc=ap.mc, g=ap.g , β=ap.β, γ=γ)
688+ tem_ap = create_customized_ap_model (mc=ap.mc, β=ap.β, γ=γ)
665689 v = tree_price(tem_ap)
666690 ax.plot(states, v, lw=2, alpha=0.6, label=rf"$\gamma = {γ}$")
667691
@@ -767,7 +791,8 @@ def consol_price(ap, ζ):
767791 M = P * ap.g(y)**(- γ)
768792
769793 # Make sure that a unique solution exists
770- test_stability(M, β)
794+ err, _ = test_stability(M, β)
795+ err.throw()
771796
772797 # Compute price
773798 I = jnp.identity(ap.n)
@@ -879,15 +904,16 @@ def call_option(ap, ζ, p_s, ϵ=1e-7):
879904 M = P * ap.g(y)**(- γ)
880905
881906 # Make sure that a unique consol price exists
882- test_stability(M, β)
907+ err, _ = test_stability(M, β)
908+ err.throw()
883909
884910 # Compute option price
885911 p = consol_price(ap, ζ)
886912 w = jnp.zeros(ap.n)
887913 error = ϵ + 1
888914
889915 def step(state):
890- w, error = state
916+ w, _ = state
891917 # Maximize across columns
892918 w_new = jnp.maximum(β * M @ w, p - p_s)
893919 # Find maximal difference of each component and update
@@ -1062,7 +1088,7 @@ Next, we'll create an instance of `AssetPriceModel` to feed into the
10621088functions
10631089
10641090```{code-cell} ipython3
1065- apm = create_ap_model (mc=mc, g=lambda x: x, β=β, γ=γ)
1091+ apm = create_customized_ap_model (mc=mc, g=lambda x: x, β=β, γ=γ)
10661092```
10671093
10681094Now we just need to call the relevant functions on the data:
@@ -1152,7 +1178,8 @@ def finite_horizon_call_option(ap, ζ, p_s, k):
11521178 M = P * ap.g(y)**(- γ)
11531179
11541180 # Make sure that a unique solution exists
1155- test_stability(M, β)
1181+ err, _ = test_stability(M, β)
1182+ err.throw()
11561183
11571184 # Compute option price
11581185 p = consol_price(ap, ζ)
0 commit comments