Skip to content

Commit 0317e4a

Browse files
committed
modified: lectures/markov_asset.md
1 parent c227f9c commit 0317e4a

File tree

1 file changed

+34
-19
lines changed

1 file changed

+34
-19
lines changed

lectures/markov_asset.md

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@ kernelspec:
3535
"Asset pricing is all about covariances" -- Lars Peter Hansen
3636
```
3737

38+
```{admonition} GPU
39+
:class: warning
40+
41+
This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and JAX for GPU programming.
42+
43+
Free GPUs are available on Google Colab. To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU.
44+
45+
Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support. If you would like to install JAX running on the `cpu` only you can use `pip install jax[cpu]`
46+
```
47+
3848
In addition to what's in Anaconda, this lecture will need the following libraries:
3949

4050
```{code-cell} ipython
@@ -976,10 +986,12 @@ $$
976986
977987
Consider the following primitives
978988
979-
```{code-cell} python3
989+
```{code-cell} ipython3
980990
n = 5 # Size of State Space
981-
P = np.full((n, n), 0.0125)
982-
P[range(n), range(n)] += 1 - P.sum(1)
991+
P = jnp.full((n, n), 0.0125)
992+
P = P.at[jnp.arange(n), jnp.arange(n)].set(
993+
P[jnp.arange(n), jnp.arange(n)] + 1 - P.sum(1)
994+
)
983995
# State values of the Markov chain
984996
s = np.array([0.95, 0.975, 1.0, 1.025, 1.05])
985997
γ = 2.0
@@ -1004,11 +1016,13 @@ Do the same for
10041016
10051017
First, let's enter the parameters:
10061018
1007-
```{code-cell} python3
1019+
```{code-cell} ipython3
10081020
n = 5
1009-
P = np.full((n, n), 0.0125)
1010-
P[range(n), range(n)] += 1 - P.sum(1)
1011-
s = np.array([0.95, 0.975, 1.0, 1.025, 1.05]) # State values
1021+
P = jnp.full((n, n), 0.0125)
1022+
P = P.at[jnp.arange(n), jnp.arange(n)].set(
1023+
P[jnp.arange(n), jnp.arange(n)] + 1 - P.sum(1)
1024+
)
1025+
s = jnp.array([0.95, 0.975, 1.0, 1.025, 1.05]) # State values
10121026
mc = qe.MarkovChain(P, state_values=s)
10131027
10141028
γ = 2.0
@@ -1020,27 +1034,27 @@ p_s = 150.0
10201034
Next, we'll create an instance of `AssetPriceModel` to feed into the
10211035
functions
10221036
1023-
```{code-cell} python3
1024-
apm = AssetPriceModel(β=β, mc=mc, γ=γ, g=lambda x: x)
1037+
```{code-cell} ipython3
1038+
apm = create_ap_model(mc=mc, g=lambda x: x, β=β, γ=γ)
10251039
```
10261040
10271041
Now we just need to call the relevant functions on the data:
10281042
1029-
```{code-cell} python3
1043+
```{code-cell} ipython3
10301044
tree_price(apm)
10311045
```
10321046
1033-
```{code-cell} python3
1047+
```{code-cell} ipython3
10341048
consol_price(apm, ζ)
10351049
```
10361050
1037-
```{code-cell} python3
1051+
```{code-cell} ipython3
10381052
call_option(apm, ζ, p_s)
10391053
```
10401054
10411055
Let's show the last two functions as a plot
10421056
1043-
```{code-cell} python3
1057+
```{code-cell} ipython3
10441058
fig, ax = plt.subplots()
10451059
ax.plot(s, consol_price(apm, ζ), label='consol')
10461060
ax.plot(s, call_option(apm, ζ, p_s), label='call option')
@@ -1101,7 +1115,7 @@ Is one higher than the other? Can you give intuition?
11011115
11021116
Here's a suitable function:
11031117
1104-
```{code-cell} python3
1118+
```{code-cell} ipython3
11051119
def finite_horizon_call_option(ap, ζ, p_s, k):
11061120
"""
11071121
Computes k period option value.
@@ -1111,15 +1125,16 @@ def finite_horizon_call_option(ap, ζ, p_s, k):
11111125
M = P * ap.g(y)**(- γ)
11121126
11131127
# Make sure that a unique solution exists
1114-
ap.test_stability(M)
1115-
1128+
test_stability(M, β)
11161129
11171130
# Compute option price
11181131
p = consol_price(ap, ζ)
1119-
w = np.zeros(ap.n)
1120-
for i in range(k):
1132+
def step(i, w):
11211133
# Maximize across columns
1122-
w = np.maximum(β * M @ w, p - p_s)
1134+
w = jnp.maximum(β * M @ w, p - p_s)
1135+
return w
1136+
1137+
w = jax.lax.fori_loop(0, k, step, jnp.zeros(ap.n))
11231138
11241139
return w
11251140
```

0 commit comments

Comments
 (0)