Skip to content

Commit 876c797

Browse files
author
Supavit Dumrongprechachan
committed
Add newton halley method and modify returning
The third and last newton method for root finding is now added with its test cases. Each newton method is modified to include only a single return statement.
1 parent 2f611cc commit 876c797

File tree

3 files changed

+134
-17
lines changed

3 files changed

+134
-17
lines changed

quantecon/optimize/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
"""
44

55
from .scalar_maximization import brent_max
6-
from .root_finding import newton, newton_secant
6+
from .root_finding import newton, newton_halley, newton_secant

quantecon/optimize/root_finding.py

Lines changed: 104 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from numba import jit, njit
33
from collections import namedtuple
44

5-
__all__ = ['newton', 'newton_secant']
5+
__all__ = ['newton', 'newton_halley', 'newton_secant']
66

77
_ECONVERGED = 0
88
_ECONVERR = -1
@@ -16,7 +16,6 @@ def _results(r):
1616
x, funcalls, iterations, flag = r
1717
return results(x, funcalls, iterations, flag == 0)
1818

19-
2019
@njit
2120
def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
2221
disp=True):
@@ -48,7 +47,6 @@ def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
4847
disp : bool, optional
4948
If True, raise a RuntimeError if the algorithm didn't converge
5049
51-
5250
Returns
5351
-------
5452
results : namedtuple
@@ -66,33 +64,120 @@ def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
6664
# Convert to float (don't use float(x0); this works also for complex x0)
6765
p0 = 1.0 * x0
6866
funcalls = 0
69-
67+
status = _ECONVERR
68+
7069
# Newton-Raphson method
7170
for itr in range(maxiter):
7271
# first evaluate fval
7372
fval = func(p0, *args)
7473
funcalls += 1
7574
# If fval is 0, a root has been found, then terminate
7675
if fval == 0:
77-
return _results((p0, funcalls, itr, _ECONVERGED))
76+
status = _ECONVERGED
77+
p = p0
78+
itr -= 1
79+
break
7880
fder = fprime(p0, *args)
7981
funcalls += 1
82+
# derivative is zero, not converged
8083
if fder == 0:
81-
# derivative is zero
82-
return _results((p0, funcalls, itr + 1, _ECONVERR))
84+
p = p0
85+
break
8386
newton_step = fval / fder
8487
# Newton step
8588
p = p0 - newton_step
8689
if abs(p - p0) < tol:
87-
return _results((p, funcalls, itr + 1, _ECONVERGED))
90+
status = _ECONVERGED
91+
break
8892
p0 = p
8993

90-
if disp:
94+
if disp and status == _ECONVERR:
9195
msg = "Failed to converge"
9296
raise RuntimeError(msg)
9397

94-
return _results((p, funcalls, itr + 1, _ECONVERR))
98+
return _results((p, funcalls, itr + 1, status))
9599

100+
@njit
101+
def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8,
102+
maxiter=50, disp=True):
103+
"""
104+
Find a zero from Halley's method using the jitted version of
105+
Scipy's.
106+
107+
`func`, `fprime`, `fprime2` must be jitted via Numba.
108+
109+
Parameters
110+
----------
111+
func : callable and jitted
112+
The function whose zero is wanted. It must be a function of a
113+
single variable of the form f(x,a,b,c...), where a,b,c... are extra
114+
arguments that can be passed in the `args` parameter.
115+
x0 : float
116+
An initial estimate of the zero that should be somewhere near the
117+
actual zero.
118+
fprime : callable and jitted
119+
The derivative of the function (when available and convenient).
120+
fprime2 : callable and jitted
121+
The second order derivative of the function
122+
args : tuple, optional
123+
Extra arguments to be used in the function call.
124+
tol : float, optional
125+
The allowable error of the zero value.
126+
maxiter : int, optional
127+
Maximum number of iterations.
128+
disp : bool, optional
129+
If True, raise a RuntimeError if the algorithm didn't converge
130+
131+
Returns
132+
-------
133+
results : namedtuple
134+
root - Estimated location where function is zero.
135+
function_calls - Number of times the function was called.
136+
iterations - Number of iterations needed to find the root.
137+
converged - True if the routine converged
138+
"""
139+
140+
if tol <= 0:
141+
raise ValueError("tol is too small <= 0")
142+
if maxiter < 1:
143+
raise ValueError("maxiter must be greater than 0")
144+
145+
# Convert to float (don't use float(x0); this works also for complex x0)
146+
p0 = 1.0 * x0
147+
funcalls = 0
148+
status = _ECONVERR
149+
150+
# Halley Method
151+
for itr in range(maxiter):
152+
# first evaluate fval
153+
fval = func(p0, *args)
154+
funcalls += 1
155+
# If fval is 0, a root has been found, then terminate
156+
if fval == 0:
157+
status = _ECONVERGED
158+
p = p0
159+
itr -= 1
160+
break
161+
fder = fprime(p0, *args)
162+
funcalls += 1
163+
# derivative is zero, not converged
164+
if fder == 0:
165+
p = p0
166+
break
167+
newton_step = fval / fder
168+
# Halley's variant
169+
fder2 = fprime2(p0, *args)
170+
p = p0 - newton_step / (1.0 - 0.5 * newton_step * fder2 / fder)
171+
if abs(p - p0) < tol:
172+
status = _ECONVERGED
173+
break
174+
p0 = p
175+
176+
if disp and status == _ECONVERR:
177+
msg = "Failed to converge"
178+
raise RuntimeError(msg)
179+
180+
return _results((p, funcalls, itr + 1, status))
96181

97182
@njit
98183
def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
@@ -121,7 +206,6 @@ def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
121206
disp : bool, optional
122207
If True, raise a RuntimeError if the algorithm didn't converge.
123208
124-
125209
Returns
126210
-------
127211
results : namedtuple
@@ -139,6 +223,7 @@ def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
139223
# Convert to float (don't use float(x0); this works also for complex x0)
140224
p0 = 1.0 * x0
141225
funcalls = 0
226+
status = _ECONVERR
142227

143228
# Secant method
144229
if x0 >= 0:
@@ -152,17 +237,21 @@ def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
152237
for itr in range(maxiter):
153238
if q1 == q0:
154239
p = (p1 + p0) / 2.0
155-
return _results((p, funcalls, itr + 1, _ECONVERGED))
240+
status = _ECONVERGED
241+
break
156242
else:
157243
p = p1 - q1 * (p1 - p0) / (q1 - q0)
158244
if np.abs(p - p1) < tol:
159-
return _results((p, funcalls, itr + 1, _ECONVERGED))
245+
status = _ECONVERGED
246+
break
160247
p0 = p1
161248
q0 = q1
162249
p1 = p
163250
q1 = func(p1, *args)
164251
funcalls += 1
165252

166-
if disp:
253+
if disp and status == _ECONVERR:
167254
msg = "Failed to converge"
168-
raise RuntimeError(msg)
255+
raise RuntimeError(msg)
256+
257+
return _results((p, funcalls, itr + 1, status))

quantecon/optimize/tests/test_root_finding.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from numpy.testing import assert_almost_equal, assert_allclose
33
from numba import njit
44

5-
from quantecon.optimize import newton, newton_secant
5+
from quantecon.optimize import newton, newton_halley, newton_secant
66

77
@njit
88
def func(x):
@@ -19,6 +19,12 @@ def func_prime(x):
1919
"""
2020
return (3*x**2)
2121

22+
@njit
23+
def func_prime2(x):
24+
"""
25+
Second order derivative for func.
26+
"""
27+
return 6*x
2228

2329
@njit
2430
def func_two(x):
@@ -35,6 +41,13 @@ def func_two_prime(x):
3541
"""
3642
return 4*np.cos(4*(x - 1/4)) + 20*x**19 + 1
3743

44+
@njit
45+
def func_two_prime2(x):
46+
"""
47+
Second order derivative for func_two
48+
"""
49+
return 380*x**18 - 16*np.sin(4*(x - 1/4))
50+
3851

3952
def test_newton_basic():
4053
"""
@@ -64,6 +77,21 @@ def test_newton_hard():
6477
fval = newton(func_two, 0.4, func_two_prime)
6578
assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01)
6679

80+
def test_halley_basic():
81+
"""
82+
Basic test for halley method
83+
"""
84+
true_fval = 1.0
85+
fval = newton_halley(func, 5, func_prime, func_prime2)
86+
assert_almost_equal(true_fval, fval.root, decimal=4)
87+
88+
def test_halley_hard():
89+
"""
90+
Harder test for halley method
91+
"""
92+
true_fval = 0.408
93+
fval = newton_halley(func_two, 0.4, func_two_prime, func_two_prime2)
94+
assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01)
6795

6896
def test_secant_basic():
6997
"""

0 commit comments

Comments
 (0)