1
+ import numpy as np
2
+ from numba import jit , njit
3
+ from collections import namedtuple
4
+
5
+ __all__ = ['newton' , 'newton_secant' ]
6
+
7
+ _ECONVERGED = 0
8
+ _ECONVERR = - 1
9
+
10
+ results = namedtuple ('results' ,
11
+ ('root function_calls iterations converged' ))
12
+
13
@njit
def _results(r):
    r"""Unpack a ``(root, funcalls, iterations, flag)`` tuple into a
    ``results`` namedtuple, converting the integer status flag into a
    convergence boolean (flag 0 means converged)."""
    root, n_calls, n_iter, status = r
    return results(root, n_calls, n_iter, status == 0)
18
+
19
+
20
@njit
def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
           disp=True):
    """
    Find a zero of a scalar function with the Newton-Raphson method;
    a jitted counterpart of SciPy's ``newton`` for scalars. No fallback
    method (such as secant) is attempted, so a usable `fprime` is
    required.

    Both `func` and `fprime` must be Numba-jitted callables; `njit` is
    recommended for performance.

    Parameters
    ----------
    func : callable and jitted
        The function whose zero is wanted. It must be a function of a
        single variable of the form f(x,a,b,c...), where a,b,c... are extra
        arguments that can be passed in the `args` parameter.
    x0 : float
        An initial estimate of the zero that should be somewhere near the
        actual zero.
    fprime : callable and jitted
        The derivative of the function (when available and convenient).
    args : tuple, optional
        Extra arguments to be used in the function call.
    tol : float, optional
        The allowable error of the zero value.
    maxiter : int, optional
        Maximum number of iterations.
    disp : bool, optional
        If True, raise a RuntimeError if the algorithm didn't converge

    Returns
    -------
    results : namedtuple
        root - Estimated location where function is zero.
        function_calls - Number of times the function was called.
        iterations - Number of iterations needed to find the root.
        converged - True if the routine converged
    """

    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Multiplying by 1.0 promotes to floating point while still working
    # for complex x0 (unlike float(x0)).
    root = 1.0 * x0
    n_calls = 0

    for itr in range(maxiter):
        fval = func(root, *args)
        n_calls += 1
        # An exact zero means the current point already is a root.
        if fval == 0:
            return _results((root, n_calls, itr, _ECONVERGED))
        deriv = fprime(root, *args)
        n_calls += 1
        # A flat tangent makes the Newton step undefined; bail out.
        if deriv == 0:
            return _results((root, n_calls, itr + 1, _ECONVERR))
        # Newton step: x_{k+1} = x_k - f(x_k) / f'(x_k)
        candidate = root - fval / deriv
        if abs(candidate - root) < tol:
            return _results((candidate, n_calls, itr + 1, _ECONVERGED))
        root = candidate

    if disp:
        msg = "Failed to converge"
        raise RuntimeError(msg)

    # Iteration budget exhausted without convergence (disp=False path).
    return _results((candidate, n_calls, itr + 1, _ECONVERR))
95
+
96
+
97
@njit
def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
                  disp=True):
    """
    Find a zero from the secant method using the jitted version of
    Scipy's secant method.

    Note that `func` must be jitted via Numba.

    Parameters
    ----------
    func : callable and jitted
        The function whose zero is wanted. It must be a function of a
        single variable of the form f(x,a,b,c...), where a,b,c... are extra
        arguments that can be passed in the `args` parameter.
    x0 : float
        An initial estimate of the zero that should be somewhere near the
        actual zero.
    args : tuple, optional
        Extra arguments to be used in the function call.
    tol : float, optional
        The allowable error of the zero value.
    maxiter : int, optional
        Maximum number of iterations.
    disp : bool, optional
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple
        root - Estimated location where function is zero.
        function_calls - Number of times the function was called.
        iterations - Number of iterations needed to find the root.
        converged - True if the routine converged
    """

    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float (don't use float(x0); this works also for complex x0)
    p0 = 1.0 * x0
    funcalls = 0

    # Secant method: build a second starting point ~1e-4 away from x0,
    # nudged away from zero so p1 != p0 even when x0 == 0.
    if x0 >= 0:
        p1 = x0 * (1 + 1e-4) + 1e-4
    else:
        p1 = x0 * (1 + 1e-4) - 1e-4
    q0 = func(p0, *args)
    funcalls += 1
    q1 = func(p1, *args)
    funcalls += 1
    for itr in range(maxiter):
        if q1 == q0:
            # Zero secant slope: cannot divide; settle on the midpoint.
            p = (p1 + p0) / 2.0
            return _results((p, funcalls, itr + 1, _ECONVERGED))
        else:
            # Secant step: interpolate the root of the chord through
            # (p0, q0) and (p1, q1).
            p = p1 - q1 * (p1 - p0) / (q1 - q0)
        if np.abs(p - p1) < tol:
            return _results((p, funcalls, itr + 1, _ECONVERGED))
        p0 = p1
        q0 = q1
        p1 = p
        q1 = func(p1, *args)
        funcalls += 1

    if disp:
        msg = "Failed to converge"
        raise RuntimeError(msg)

    # BUG FIX: the original fell off the end here, implicitly returning
    # None when disp=False and the iteration budget was exhausted.
    # Return a non-converged results record instead, mirroring `newton`.
    # (`p` is always bound: maxiter >= 1 and both loop branches assign it.)
    return _results((p, funcalls, itr + 1, _ECONVERR))
0 commit comments