Skip to content

Commit 85f9eb3

Browse files
author
Supavit Dumrongprechachan
committed
Add jitted newton methods for root finding
Based on Scipy's newton methods, we add two variants of newton root finding - Newton Raphson and secant method. All methods are jitted through Numba with nopython mode. Relevant information (apart from the root) is also returned in the result as python namedtuple since Numba supports it.
1 parent d6130c1 commit 85f9eb3

File tree

2 files changed

+169
-1
lines changed

2 files changed

+169
-1
lines changed

quantecon/optimize/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
"""
44

55
from .scalar_maximization import brent_max
6-
6+
from .root_finding import newton, newton_secant

quantecon/optimize/root_finding.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import numpy as np
2+
from numba import jit, njit
3+
from collections import namedtuple
4+
5+
__all__ = ['newton', 'newton_secant']
6+
7+
_ECONVERGED = 0
8+
_ECONVERR = -1
9+
10+
results = namedtuple('results',
11+
('root function_calls iterations converged'))
12+
13+
@njit
14+
def _results(r):
15+
r"""Select from a tuple of(root, funccalls, iterations, flag)"""
16+
x, funcalls, iterations, flag = r
17+
return results(x, funcalls, iterations, flag == 0)
18+
19+
20+
@njit
21+
def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
22+
disp=True):
23+
"""
24+
Find a zero from the Newton-Raphson method using the jitted version of
25+
Scipy's newton for scalars. Note that this does not provide an alternative
26+
method such as secant. Thus, it is important that `fprime` can be provided.
27+
28+
Note that `func` and `fprime` must be jitted via Numba.
29+
They are recommended to be `njit` for performance.
30+
31+
Parameters
32+
----------
33+
func : callable and jitted
34+
The function whose zero is wanted. It must be a function of a
35+
single variable of the form f(x,a,b,c...), where a,b,c... are extra
36+
arguments that can be passed in the `args` parameter.
37+
x0 : float
38+
An initial estimate of the zero that should be somewhere near the
39+
actual zero.
40+
fprime : callable and jitted
41+
The derivative of the function (when available and convenient).
42+
args : tuple, optional
43+
Extra arguments to be used in the function call.
44+
tol : float, optional
45+
The allowable error of the zero value.
46+
maxiter : int, optional
47+
Maximum number of iterations.
48+
disp : bool, optional
49+
If True, raise a RuntimeError if the algorithm didn't converge
50+
51+
52+
Returns
53+
-------
54+
results : namedtuple
55+
root - Estimated location where function is zero.
56+
function_calls - Number of times the function was called.
57+
iterations - Number of iterations needed to find the root.
58+
converged - True if the routine converged
59+
"""
60+
61+
if tol <= 0:
62+
raise ValueError("tol is too small <= 0")
63+
if maxiter < 1:
64+
raise ValueError("maxiter must be greater than 0")
65+
66+
# Convert to float (don't use float(x0); this works also for complex x0)
67+
p0 = 1.0 * x0
68+
funcalls = 0
69+
70+
# Newton-Raphson method
71+
for itr in range(maxiter):
72+
# first evaluate fval
73+
fval = func(p0, *args)
74+
funcalls += 1
75+
# If fval is 0, a root has been found, then terminate
76+
if fval == 0:
77+
return _results((p0, funcalls, itr, _ECONVERGED))
78+
fder = fprime(p0, *args)
79+
funcalls += 1
80+
if fder == 0:
81+
# derivative is zero
82+
return _results((p0, funcalls, itr + 1, _ECONVERR))
83+
newton_step = fval / fder
84+
# Newton step
85+
p = p0 - newton_step
86+
if abs(p - p0) < tol:
87+
return _results((p, funcalls, itr + 1, _ECONVERGED))
88+
p0 = p
89+
90+
if disp:
91+
msg = "Failed to converge"
92+
raise RuntimeError(msg)
93+
94+
return _results((p, funcalls, itr + 1, _ECONVERR))
95+
96+
97+
@njit
98+
def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
99+
disp=True):
100+
"""
101+
Find a zero from the secant method using the jitted version of
102+
Scipy's secant method.
103+
104+
Note that `func` must be jitted via Numba.
105+
106+
Parameters
107+
----------
108+
func : callable and jitted
109+
The function whose zero is wanted. It must be a function of a
110+
single variable of the form f(x,a,b,c...), where a,b,c... are extra
111+
arguments that can be passed in the `args` parameter.
112+
x0 : float
113+
An initial estimate of the zero that should be somewhere near the
114+
actual zero.
115+
args : tuple, optional
116+
Extra arguments to be used in the function call.
117+
tol : float, optional
118+
The allowable error of the zero value.
119+
maxiter : int, optional
120+
Maximum number of iterations.
121+
disp : bool, optional
122+
If True, raise a RuntimeError if the algorithm didn't converge.
123+
124+
125+
Returns
126+
-------
127+
results : namedtuple
128+
root - Estimated location where function is zero.
129+
function_calls - Number of times the function was called.
130+
iterations - Number of iterations needed to find the root.
131+
converged - True if the routine converged
132+
"""
133+
134+
if tol <= 0:
135+
raise ValueError("tol is too small <= 0")
136+
if maxiter < 1:
137+
raise ValueError("maxiter must be greater than 0")
138+
139+
# Convert to float (don't use float(x0); this works also for complex x0)
140+
p0 = 1.0 * x0
141+
funcalls = 0
142+
143+
# Secant method
144+
if x0 >= 0:
145+
p1 = x0 * (1 + 1e-4) + 1e-4
146+
else:
147+
p1 = x0 * (1 + 1e-4) - 1e-4
148+
q0 = func(p0, *args)
149+
funcalls += 1
150+
q1 = func(p1, *args)
151+
funcalls += 1
152+
for itr in range(maxiter):
153+
if q1 == q0:
154+
p = (p1 + p0) / 2.0
155+
return _results((p, funcalls, itr + 1, _ECONVERGED))
156+
else:
157+
p = p1 - q1 * (p1 - p0) / (q1 - q0)
158+
if np.abs(p - p1) < tol:
159+
return _results((p, funcalls, itr + 1, _ECONVERGED))
160+
p0 = p1
161+
q0 = q1
162+
p1 = p
163+
q1 = func(p1, *args)
164+
funcalls += 1
165+
166+
if disp:
167+
msg = "Failed to converge"
168+
raise RuntimeError(msg)

0 commit comments

Comments
 (0)