Commit 77d83d0

ENH: Add ADAM solver, see #984
1 parent 009e19b commit 77d83d0

3 files changed: +88 −1 lines changed

odl/solvers/smooth/gradient.py

Lines changed: 72 additions & 1 deletion
@@ -27,7 +27,7 @@
 from odl.solvers.util import ConstantLineSearch


-__all__ = ('steepest_descent',)
+__all__ = ('steepest_descent', 'adam')


 # TODO: update all docs
@@ -110,6 +110,77 @@ def steepest_descent(f, x, line_search=1.0, maxiter=1000, tol=1e-16,
             callback(x)


+def adam(f, x, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
+         maxiter=1000, tol=1e-16, callback=None):
+    """ADAM method to minimize an objective function.
+
+    General implementation of ADAM for solving
+
+    .. math::
+        \min f(x)
+
+    The algorithm is intended for unconstrained problems.
+
+    The algorithm is described in
+    `Adam: A Method for Stochastic Optimization
+    <https://arxiv.org/abs/1412.6980>`_. All parameter names are taken from
+    that article.
+
+    Parameters
+    ----------
+    f : `Functional`
+        Goal functional. Needs to have ``f.gradient``.
+    x : ``f.domain`` element
+        Starting point of the iteration
+    learning_rate : float, optional
+        Step length of the method.
+    beta1 : float, optional
+        Update rate for first order moment estimate.
+    beta2 : float, optional
+        Update rate for second order moment estimate.
+    eps : float, optional
+        A small constant for numerical stability.
+    maxiter : int, optional
+        Maximum number of iterations.
+    tol : float, optional
+        Tolerance that should be used for terminating the iteration.
+    callback : callable, optional
+        Object executing code per iteration, e.g. plotting each iterate
+
+    See Also
+    --------
+    odl.solvers.smooth.gradient.steepest_descent : simple steepest descent
+    odl.solvers.iterative.iterative.landweber :
+        Optimized solver for the case ``f(x) = ||Ax - b||_2^2``
+    odl.solvers.iterative.iterative.conjugate_gradient :
+        Optimized solver for the case ``f(x) = x^T Ax - 2 x^T b``
+    """
+    grad = f.gradient
+    if x not in grad.domain:
+        raise TypeError('`x` {!r} is not in the domain of `grad` {!r}'
+                        ''.format(x, grad.domain))
+
+    m = grad.domain.zero()
+    v = grad.domain.zero()
+
+    grad_x = grad.range.element()
+    for _ in range(maxiter):
+        grad(x, out=grad_x)
+
+        if grad_x.norm() < tol:
+            return
+
+        m.lincomb(beta1, m, 1 - beta1, grad_x)
+        v.lincomb(beta2, v, 1 - beta2, grad_x ** 2)
+
+        step = learning_rate * np.sqrt(1 - beta2) / (1 - beta1)
+
+        x.lincomb(1, x, -step, m / np.sqrt(v + eps))
+
+        if callback is not None:
+            callback(x)
+
+
 if __name__ == '__main__':
     # pylint: disable=wrong-import-position
     from odl.util.testutils import run_doctests
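
For reference, the per-iteration update implemented by the loop above can be written out as follows (a sketch in the notation of the referenced article, not part of the committed code):

    m_k = \beta_1 m_{k-1} + (1 - \beta_1) \nabla f(x_{k-1})
    v_k = \beta_2 v_{k-1} + (1 - \beta_2) \nabla f(x_{k-1})^2
    x_k = x_{k-1} - \alpha \frac{\sqrt{1 - \beta_2}}{1 - \beta_1} \cdot \frac{m_k}{\sqrt{v_k + \epsilon}}

Here \alpha is the ``learning_rate``, and the squaring, division and square root act pointwise on space elements. This variant applies the fixed correction factor \sqrt{1 - \beta_2} / (1 - \beta_1) in place of the per-iteration bias-correction factors \sqrt{1 - \beta_2^k} / (1 - \beta_1^k) from the article.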

odl/test/solvers/iterative/iterative_test.py

Lines changed: 7 additions & 0 deletions
@@ -31,6 +31,7 @@
 # Find the valid projectors
 @pytest.fixture(scope="module",
                 params=['steepest_descent',
+                        'adam',
                         'landweber',
                         'conjugate_gradient',
                         'conjugate_gradient_normal',
@@ -47,6 +48,12 @@ def solver(op, x, rhs):
             func = odl.solvers.L2NormSquared(op.domain) * (op - rhs)

             odl.solvers.steepest_descent(func, x, line_search=0.5 / norm2)
+    elif solver_name == 'adam':
+        def solver(op, x, rhs):
+            norm2 = op.adjoint(op(x)).norm() / x.norm()
+            func = odl.solvers.L2NormSquared(op.domain) * (op - rhs)
+
+            odl.solvers.adam(func, x, learning_rate=4.0 / norm2, maxiter=150)
     elif solver_name == 'landweber':
         def solver(op, x, rhs):
             norm2 = op.adjoint(op(x)).norm() / x.norm()
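
The fixture entry above is exercised through the shared iterative-solver tests. The same pattern can be reproduced standalone roughly as follows (a minimal sketch; the space, operator and parameter values are illustrative and not taken from the test suite):

    import odl

    # Solve A x = b by minimizing ||A x - b||_2^2 with the newly added
    # adam solver, mirroring the construction used in the fixture above.
    space = odl.rn(5)
    A = odl.IdentityOperator(space)      # stand-in for a linear forward operator
    b = space.element([1, 2, 3, 4, 5])   # synthetic right-hand side

    # Composing the squared L2 norm with (A - b) gives x -> ||A x - b||_2^2
    func = odl.solvers.L2NormSquared(A.domain) * (A - b)

    x = space.zero()
    odl.solvers.adam(func, x, learning_rate=0.1, maxiter=500)
    # x should now be close to b, the minimizer for A = identity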

odl/test/solvers/smooth/smooth_test.py

Lines changed: 9 additions & 0 deletions
@@ -136,6 +136,15 @@ def test_steepest_descent(functional):
     assert functional(x) < 1e-3


+def test_adam(functional):
+    """Test the ``adam`` solver."""
+
+    x = functional.domain.one()
+    odl.solvers.adam(functional, x, tol=1e-2, learning_rate=0.5)
+
+    assert functional(x) < 1e-3
+
+
 def test_conjguate_gradient_nonlinear(functional, nonlinear_cg_beta):
     """Test the ``conjugate_gradient_nonlinear`` solver."""
     line_search = odl.solvers.BacktrackingLineSearch(functional)
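
Beyond the convergence check above, the new solver also takes a ``callback``, as documented in its docstring. A minimal sketch of hooking one up to a smooth functional (the quadratic and the parameter values here are illustrative, not taken from the test module):

    import odl

    space = odl.rn(3)
    functional = odl.solvers.L2NormSquared(space)

    x = space.one()
    # CallbackPrintIteration prints the iteration number after every step
    callback = odl.solvers.CallbackPrintIteration()
    odl.solvers.adam(functional, x, learning_rate=0.5, maxiter=50,
                     tol=1e-2, callback=callback)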
