|
27 | 27 | from odl.solvers.util import ConstantLineSearch |
28 | 28 |
|
29 | 29 |
|
30 | | -__all__ = ('steepest_descent',) |
| 30 | +__all__ = ('steepest_descent', 'adam') |
31 | 31 |
|
32 | 32 |
|
33 | 33 | # TODO: update all docs |
@@ -110,6 +110,77 @@ def steepest_descent(f, x, line_search=1.0, maxiter=1000, tol=1e-16, |
110 | 110 | callback(x) |
111 | 111 |
|
112 | 112 |
|
| 113 | +def adam(f, x, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, |
| 114 | + maxiter=1000, tol=1e-16, callback=None): |
| 115 | + r"""ADAM method to minimize an objective function. |
| 116 | +
|
| 117 | + General implementation of ADAM for solving |
| 118 | +
|
| 119 | + .. math:: |
| 120 | + \min f(x) |
| 121 | +
|
| 122 | + The algorithm is intended for unconstrained problems. |
| 123 | +
|
| 124 | + The algorithm is described in |
| 125 | + `Adam: A Method for Stochastic Optimization |
| 126 | + <https://arxiv.org/abs/1412.6980>`_. All parameter names are taken from |
| 127 | + that article. |
| 128 | +
|
| 129 | + Parameters |
| 130 | + ---------- |
| 131 | + f : `Functional` |
| 132 | + Goal functional. Needs to have ``f.gradient``. |
| 133 | + x : ``f.domain`` element |
| 134 | + Starting point of the iteration. |
| 135 | + learning_rate : float, optional |
| 136 | + Step length of the method. |
| 137 | + beta1 : float, optional |
| 138 | + Exponential decay rate for the first order moment estimate. |
| 139 | + beta2 : float, optional |
| 140 | + Exponential decay rate for the second order moment estimate. |
| 141 | + eps : float, optional |
| 142 | + A small constant for numerical stability. |
| 143 | + maxiter : int, optional |
| 144 | + Maximum number of iterations. |
| 145 | + tol : float, optional |
| 146 | + Tolerance on the gradient norm used for terminating the iteration. |
| 147 | + callback : callable, optional |
| 148 | + Object executing code per iteration, e.g. plotting each iterate. |
| 149 | +
|
| 150 | + See Also |
| 151 | + -------- |
| 152 | + odl.solvers.smooth.gradient.steepest_descent : simple steepest descent |
| 153 | + odl.solvers.iterative.iterative.landweber : |
| 154 | + Optimized solver for the case ``f(x) = ||Ax - b||_2^2`` |
| 155 | + odl.solvers.iterative.iterative.conjugate_gradient : |
| 156 | + Optimized solver for the case ``f(x) = x^T Ax - 2 x^T b`` |
| 157 | + """ |
| 158 | + grad = f.gradient |
| 159 | + if x not in grad.domain: |
| 160 | + raise TypeError('`x` {!r} is not in the domain of `grad` {!r}' |
| 161 | + ''.format(x, grad.domain)) |
| 162 | + |
| 163 | + m = grad.domain.zero() |
| 164 | + v = grad.domain.zero() |
| 165 | + |
| 166 | + grad_x = grad.range.element() |
| 167 | + for _ in range(maxiter): |
| 168 | + grad(x, out=grad_x) |
| 169 | + |
| 170 | + if grad_x.norm() < tol: |
| 171 | + return |
| 172 | + |
| 173 | + m.lincomb(beta1, m, 1 - beta1, grad_x)  # biased first moment estimate |
| 174 | + v.lincomb(beta2, v, 1 - beta2, grad_x ** 2)  # biased second raw moment estimate |
| 175 | + |
| 176 | + step = learning_rate * np.sqrt(1 - beta2) / (1 - beta1)  # constant bias-correction factor |
| 177 | + |
| 178 | + x.lincomb(1, x, -step, m / np.sqrt(v + eps))  # eps added inside the square root |
| 179 | + |
| 180 | + if callback is not None: |
| 181 | + callback(x) |
| 182 | + |
| 183 | + |
113 | 184 | if __name__ == '__main__': |
114 | 185 | # pylint: disable=wrong-import-position |
115 | 186 | from odl.util.testutils import run_doctests |
|
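For comparison, here is a minimal NumPy sketch of Algorithm 1 from the cited paper, with the per-iteration bias correction 1 - beta**t. The diff above folds that correction into a constant step factor and adds eps inside the square root, so it matches the paper only approximately; note also that in ODL, m.lincomb(a, u, b, w) evaluates m = a*u + b*w in place, which is what the two moment updates do, and that np refers to NumPy, presumably imported elsewhere in the module. The helper name adam_reference and the quadratic test gradient below are purely illustrative, not part of the ODL API.

import numpy as np

def adam_reference(grad, x, lr=1e-3, beta1=0.9, beta2=0.999,
                   eps=1e-8, maxiter=1000, tol=1e-16):
    """Adam as in the paper, with iteration-dependent bias correction."""
    m = np.zeros_like(x)
    v = np.zeros_like(x)
    for t in range(1, maxiter + 1):
        g = grad(x)
        if np.linalg.norm(g) < tol:
            break
        m = beta1 * m + (1 - beta1) * g        # biased 1st moment estimate
        v = beta2 * v + (1 - beta2) * g ** 2   # biased 2nd raw moment estimate
        m_hat = m / (1 - beta1 ** t)           # bias-corrected 1st moment
        v_hat = v / (1 - beta2 ** t)           # bias-corrected 2nd moment
        x = x - lr * m_hat / (np.sqrt(v_hat) + eps)
    return x

# Toy check: minimize f(x) = ||x||_2^2, whose gradient is 2*x.
x_opt = adam_reference(lambda x: 2.0 * x, np.array([1.0, -2.0, 3.0]), lr=0.1)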
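And a hedged usage sketch of the new solver itself, assuming it ends up exported as odl.solvers.adam (consistent with the extended __all__ above). The toy objective uses the existing odl.solvers.L2NormSquared functional, and x is updated in place.

import odl

space = odl.rn(3)
f = odl.solvers.L2NormSquared(space)   # f(x) = ||x||_2^2, minimized at x = 0
x = space.element([1.0, 2.0, 3.0])     # starting point, modified in place

odl.solvers.adam(f, x, learning_rate=0.1, maxiter=2000,
                 callback=odl.solvers.CallbackPrintIteration())

print(x)  # should end up close to space.zero()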