@@ -83,7 +83,9 @@ from typing import NamedTuple
8383from scipy.optimize import root
8484import jax.numpy as jnp
8585import jax
86- jax.config.update("jax_enable_x64", True) # Enable 64-bit precision
86+
87+ # Enable 64-bit precision
88+ jax.config.update("jax_enable_x64", True)
8789```
8890
8991## Fixed point computation using Newton's method
@@ -732,7 +734,7 @@ solution = root(
732734 lambda p: e(p, A, b, c),
733735 init_p,
734736 jac = lambda p: jacobian_e(p, A, b, c),
735- method="hybr"
737+ method="hybr",
736738)
737739```
738740
@@ -980,27 +982,29 @@ Note the error is very small.
980982We can also test our results on the known solution
981983
982984``` {code-cell} ipython3
983- A = jnp.array([[2.0, 0.0, 0.0],
984- [0.0, 2.0, 0.0],
985- [0.0, 0.0, 2.0]])
985+ A = jnp.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]])
986986
987987s = 0.3
988988α = 0.3
989989δ = 0.4
990990
991991init = jnp.repeat(1.0, 3)
992+ ```
992993
994+ ``` {code-cell} ipython3
995+ %%time
993996
994- %time k = newton(lambda k: multivariate_solow(k, A=A, s=s, α=α, δ=δ) - k, \
995- init)
997+ k = newton(lambda k: multivariate_solow(k, A=A, s=s, α=α, δ=δ) - k, init)
996998```
997999
9981000The result is very close to the ground truth but still slightly different.
9991001
10001002``` {code-cell} ipython3
1001- %time k = newton(lambda k: multivariate_solow(k, A=A, s=s, α=α, δ=δ) - k, \
1002- init,\
1003- tol=1e-7)
1003+ %%time
1004+
1005+ k = newton(
1006+ lambda k: multivariate_solow(k, A=A, s=s, α=α, δ=δ) - k, init, tol=1e-7
1007+ )
10041008```
10051009
10061010We can see it steps towards a more accurate solution.
@@ -1073,12 +1077,9 @@ Let's run through each initial guess and check the output
10731077
10741078attempt = 1
10751079for init in initLs:
1076- print(f'Attempt {attempt}: Starting value is {init} \n')
1077- %time p = newton(lambda p: e(p, A, b, c), \
1078- init, \
1079- tol=1e-15, \
1080- max_iter=15)
1081- print('-'*64)
1080+ print(f"Attempt {attempt}: Starting value is {init} \n")
1081+ %time p = newton(lambda p: e(p, A, b, c), init, tol=1e-15, max_iter=15)
1082+ print("-" * 64)
10821083 attempt += 1
10831084```
10841085
0 commit comments