Skip to content

Commit a188aa8

Browse files
committed
WIP: make functionals and prox ops work for product spaces
1 parent 2efc752 commit a188aa8

File tree

3 files changed

+49
-56
lines changed

3 files changed

+49
-56
lines changed

odl/solvers/functional/default_functionals.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,10 @@ def __init__(self):
148148

149149
def _call(self, x):
150150
"""Apply the gradient operator to the given point."""
151-
return np.sign(x)
151+
if isinstance(self.domain, ProductSpace):
152+
return self.domain.apply(np.sign, x)
153+
else:
154+
return np.sign(x)
152155

153156
def derivative(self, x):
154157
"""Derivative is a.e. zero."""
@@ -1127,14 +1130,23 @@ def _call(self, x):
11271130
import scipy.special
11281131

11291132
if self.prior is None:
1130-
tmp = self.domain.inner(self.domain.one(), x - 1 - np.log(x))
1133+
if isinstance(self.domain, ProductSpace):
1134+
log_x = self.domain.apply(np.log, x)
1135+
else:
1136+
log_x = np.log(x)
1137+
tmp = self.domain.inner(self.domain.one(), x - 1 - log_x)
1138+
11311139
else:
1132-
tmp = self.domain.inner(
1133-
self.domain.one(),
1134-
x - self.prior + scipy.special.xlogy(
1135-
self.prior, self.prior / x
1136-
),
1137-
)
1140+
g = self.prior
1141+
if isinstance(self.domain, ProductSpace):
1142+
xlogy = self.domain.apply2(
1143+
lambda v, i: scipy.special.xlogy(g[i], g[i] / v), x
1144+
)
1145+
else:
1146+
xlogy = scipy.special.xlogy(g, g / x)
1147+
1148+
tmp = self.domain.inner(self.domain.one(), x - g + xlogy)
1149+
11381150
if np.isnan(tmp):
11391151
# In this case, some element was less than or equal to zero
11401152
return np.inf

odl/solvers/functional/functional.py

Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,19 @@
88
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
99
# obtain one at https://mozilla.org/MPL/2.0/.
1010

11-
from __future__ import print_function, division, absolute_import
11+
from __future__ import absolute_import, division, print_function
12+
1213
import numpy as np
1314

15+
from odl.operator.default_ops import (
16+
ConstantOperator, IdentityOperator, InnerProductOperator)
1417
from odl.operator.operator import (
15-
Operator, OperatorComp, OperatorLeftScalarMult, OperatorRightScalarMult,
16-
OperatorRightVectorMult, OperatorSum, OperatorPointwiseProduct)
17-
from odl.operator.default_ops import (IdentityOperator, ConstantOperator)
18-
from odl.solvers.nonsmooth import (proximal_arg_scaling, proximal_translation,
19-
proximal_quadratic_perturbation,
20-
proximal_const_func, proximal_convex_conj)
21-
from odl.util import signature_string, indent
22-
18+
Operator, OperatorComp, OperatorLeftScalarMult, OperatorPointwiseProduct,
19+
OperatorRightScalarMult, OperatorRightVectorMult, OperatorSum)
20+
from odl.solvers.nonsmooth import (
21+
proximal_arg_scaling, proximal_const_func, proximal_convex_conj,
22+
proximal_quadratic_perturbation, proximal_translation)
23+
from odl.util import indent, signature_string
2324

2425
__all__ = ('Functional', 'FunctionalLeftScalarMult',
2526
'FunctionalRightScalarMult', 'FunctionalComp',
@@ -204,7 +205,7 @@ def derivative(self, point):
204205
-------
205206
derivative : `Operator`
206207
"""
207-
return self.gradient(point).T
208+
return InnerProductOperator(self.domain, self.gradient(point))
208209

209210
def translated(self, shift):
210211
"""Return a translation of the functional.
@@ -1399,33 +1400,18 @@ def __init__(self, functional, point, subgrad):
13991400
raise TypeError('`functional` {} not an instance of ``Functional``'
14001401
''.format(functional))
14011402
self.__functional = functional
1402-
1403-
if point not in functional.domain:
1404-
raise ValueError('`point` {} is not in `functional.domain` {}'
1405-
''.format(point, functional.domain))
1406-
self.__point = point
1407-
1408-
if subgrad not in functional.domain:
1409-
raise TypeError(
1410-
'`subgrad` must be an element in `functional.domain`, got '
1411-
'{}'.format(subgrad))
1412-
self.__subgrad = subgrad
1413-
1414-
self.__constant = (
1415-
-functional(point)
1416-
+ functional.domain.inner(subgrad, point)
1417-
)
1418-
1403+
space = functional.domain
1404+
self.__point = space.element(point)
1405+
self.__subgrad = space.element(subgrad)
1406+
self.__constant = -functional(point) + space.inner(subgrad, point)
14191407
self.__bregman_dist = FunctionalQuadraticPerturb(
1420-
functional, linear_term=-subgrad, constant=self.__constant)
1421-
1422-
grad_lipschitz = (
1423-
functional.grad_lipschitz + functional.domain.norm(subgrad)
1408+
functional, linear_term=-subgrad, constant=self.__constant
14241409
)
1410+
grad_lipschitz = functional.grad_lipschitz + space.norm(subgrad)
14251411

14261412
super(BregmanDistance, self).__init__(
1427-
space=functional.domain, linear=False,
1428-
grad_lipschitz=grad_lipschitz)
1413+
space, linear=False, grad_lipschitz=grad_lipschitz
1414+
)
14291415

14301416
@property
14311417
def functional(self):
@@ -1459,15 +1445,10 @@ def proximal(self):
14591445
@property
14601446
def gradient(self):
14611447
"""Gradient operator of the functional."""
1462-
try:
1463-
op_to_return = self.functional.gradient
1464-
except NotImplementedError:
1465-
raise NotImplementedError(
1466-
'`self.functional.gradient` is not implemented for '
1467-
'`self.functional` {}'.format(self.functional))
1468-
1469-
op_to_return = op_to_return - ConstantOperator(self.subgrad)
1470-
return op_to_return
1448+
return (
1449+
self.functional.gradient
1450+
- ConstantOperator(self.domain, self.subgrad)
1451+
)
14711452

14721453
def __repr__(self):
14731454
'''Return ``repr(self)``.'''

odl/solvers/nonsmooth/proximal_operators.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@
2121
Foundations and Trends in Optimization, 1 (2014), pp 127-239.
2222
"""
2323

24-
from __future__ import print_function, division, absolute_import
24+
from __future__ import absolute_import, division, print_function
25+
2526
import numpy as np
2627

2728
from odl.operator import (
28-
Operator, IdentityOperator, ConstantOperator, DiagonalOperator,
29-
PointwiseNorm, MultiplyOperator)
29+
ConstantOperator, DiagonalOperator, IdentityOperator, MultiplyOperator,
30+
Operator, PointwiseNorm)
3031
from odl.space import ProductSpace
3132

32-
3333
__all__ = ('combine_proximals', 'proximal_convex_conj', 'proximal_translation',
3434
'proximal_arg_scaling', 'proximal_quadratic_perturbation',
3535
'proximal_composition', 'proximal_const_func',
@@ -799,7 +799,7 @@ def _call(self, x, out):
799799
if step < 1.0:
800800
self.range.lincomb(1 - step, x, out=out)
801801
else:
802-
out[:] = 0
802+
self.range.lincomb(0, out, out=out)
803803

804804
else:
805805
x_norm = self.domain.norm(x - g) * (1 + eps)
@@ -811,7 +811,7 @@ def _call(self, x, out):
811811
if step < 1.0:
812812
self.range.lincomb(1 - step, x, step, g, out=out)
813813
else:
814-
out[:] = g
814+
self.range.lincomb(1, g, out=out)
815815

816816
return ProximalL2
817817

0 commit comments

Comments
 (0)