Skip to content

Commit 211ec24

Browse files
committed
Made functions more JAX-compatible
1 parent 02dfe6e commit 211ec24

File tree

1 file changed

+50
-23
lines changed

1 file changed

+50
-23
lines changed

lectures/markov_asset.md

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ import quantecon as qe
8888
import jax
8989
import jax.numpy as jnp
9090
from jax.numpy.linalg import eigvals, solve
91+
from jax.experimental import checkify
9192
from typing import NamedTuple
9293
```
9394

@@ -567,6 +568,20 @@ We will define a function tree_price to compute $v$ given parameters stored in
567568
the class AssetPriceModel
568569
569570
```{code-cell} ipython3
571+
class MarkovChain(NamedTuple):
572+
"""
573+
A class that stores the primitives of a Markov chain.
574+
Parameters
575+
----------
576+
P : jnp.ndarray
577+
Transition matrix
578+
state_values : jnp.ndarray
579+
The values associated with each state
580+
"""
581+
P: jnp.ndarray
582+
state_values: jnp.ndarray
583+
584+
570585
class AssetPriceModel(NamedTuple):
571586
"""
572587
A class that stores the primitives of the asset pricing model.
@@ -584,34 +599,42 @@ class AssetPriceModel(NamedTuple):
584599
n: int
585600
The number of states
586601
"""
587-
mc: qe.MarkovChain
602+
mc: MarkovChain
588603
g: callable
589604
β: float
590605
γ: float
591606
n: int
592607
593608
594-
def create_ap_model(mc=None, g=jnp.exp, β=0.96, γ=2.0):
595-
"""Create an AssetPriceModel class"""
596-
if mc is None:
597-
n, ρ, σ = 25, 0.9, 0.02
598-
mc = qe.tauchen(n, ρ, σ)
599-
else:
600-
mc = mc
601-
n = mc.P.shape[0]
609+
def create_ap_model(g=jnp.exp, β=0.96, γ=2.0):
610+
"""Create an AssetPriceModel class using standard Markov chain."""
611+
n, ρ, σ = 25, 0.9, 0.02
612+
qe_mc = qe.tauchen(n, ρ, σ)
613+
P = jnp.array(qe_mc.P)
614+
state_values = jnp.array(qe_mc.state_values)
615+
mc = MarkovChain(P=P, state_values=state_values)
602616
603617
return AssetPriceModel(mc=mc, g=g, β=β, γ=γ, n=n)
604618
605619
620+
def create_customized_ap_model(mc: MarkovChain, g=jnp.exp, β=0.96, γ=2.0):
621+
"""Create an AssetPriceModel class using a customized Markov chain."""
622+
n = mc.P.shape[0]
623+
return AssetPriceModel(mc=mc, g=g, β=β, γ=γ, n=n)
624+
625+
606626
def test_stability(Q, β):
607-
"""
608-
Stability test for a given matrix Q.
609-
"""
610-
sr = np.max(np.abs(eigvals(Q)))
611-
if not sr < 1 / β:
612-
msg = f"Spectral radius condition failed with radius = {sr}"
613-
raise ValueError(msg)
627+
"""Stability test for a given matrix Q."""
628+
sr = jnp.max(jnp.abs(eigvals(Q)))
629+
checkify.check(
630+
sr < 1 / β,
631+
"Spectral radius condition failed with radius = {sr}", sr=sr
632+
)
633+
return sr
634+
614635
636+
# Wrap the check function to be JIT-safe
637+
test_stability = checkify.checkify(test_stability, errors=checkify.user_checks)
615638
616639
def tree_price(ap):
617640
"""
@@ -633,7 +656,8 @@ def tree_price(ap):
633656
J = P * ap.g(y)**(1 - γ)
634657
635658
# Make sure that a unique solution exists
636-
test_stability(J, β)
659+
err, out = test_stability(J, β)
660+
err.throw()
637661
638662
# Compute v
639663
I = jnp.identity(ap.n)
@@ -661,7 +685,7 @@ states = ap.mc.state_values
661685
fig, ax = plt.subplots()
662686
663687
for γ in γs:
664-
tem_ap = create_ap_model(mc=ap.mc, g=ap.g, β=ap.β, γ=γ)
688+
tem_ap = create_customized_ap_model(mc=ap.mc, β=ap.β, γ=γ)
665689
v = tree_price(tem_ap)
666690
ax.plot(states, v, lw=2, alpha=0.6, label=rf"$\gamma = {γ}$")
667691
@@ -767,7 +791,8 @@ def consol_price(ap, ζ):
767791
M = P * ap.g(y)**(- γ)
768792
769793
# Make sure that a unique solution exists
770-
test_stability(M, β)
794+
err, _ = test_stability(M, β)
795+
err.throw()
771796
772797
# Compute price
773798
I = jnp.identity(ap.n)
@@ -879,15 +904,16 @@ def call_option(ap, ζ, p_s, ϵ=1e-7):
879904
M = P * ap.g(y)**(- γ)
880905
881906
# Make sure that a unique consol price exists
882-
test_stability(M, β)
907+
err, _ = test_stability(M, β)
908+
err.throw()
883909
884910
# Compute option price
885911
p = consol_price(ap, ζ)
886912
w = jnp.zeros(ap.n)
887913
error = ϵ + 1
888914
889915
def step(state):
890-
w, error = state
916+
w, _ = state
891917
# Maximize across columns
892918
w_new = jnp.maximum(β * M @ w, p - p_s)
893919
# Find maximal difference of each component and update
@@ -1062,7 +1088,7 @@ Next, we'll create an instance of `AssetPriceModel` to feed into the
10621088
functions
10631089
10641090
```{code-cell} ipython3
1065-
apm = create_ap_model(mc=mc, g=lambda x: x, β=β, γ=γ)
1091+
apm = create_customized_ap_model(mc=mc, g=lambda x: x, β=β, γ=γ)
10661092
```
10671093
10681094
Now we just need to call the relevant functions on the data:
@@ -1152,7 +1178,8 @@ def finite_horizon_call_option(ap, ζ, p_s, k):
11521178
M = P * ap.g(y)**(- γ)
11531179
11541180
# Make sure that a unique solution exists
1155-
test_stability(M, β)
1181+
err, _ = test_stability(M, β)
1182+
err.throw()
11561183
11571184
# Compute option price
11581185
p = consol_price(ap, ζ)

0 commit comments

Comments
 (0)