|
8 | 8 | # v. 2.0. If a copy of the MPL was not distributed with this file, You can |
9 | 9 | # obtain one at https://mozilla.org/MPL/2.0/. |
10 | 10 |
|
11 | | -from __future__ import print_function, division, absolute_import |
| 11 | +from __future__ import absolute_import, division, print_function |
| 12 | + |
12 | 13 | import numpy as np |
13 | 14 |
|
| 15 | +from odl.operator.default_ops import ( |
| 16 | + ConstantOperator, IdentityOperator, InnerProductOperator) |
14 | 17 | from odl.operator.operator import ( |
15 | | - Operator, OperatorComp, OperatorLeftScalarMult, OperatorRightScalarMult, |
16 | | - OperatorRightVectorMult, OperatorSum, OperatorPointwiseProduct) |
17 | | -from odl.operator.default_ops import (IdentityOperator, ConstantOperator) |
18 | | -from odl.solvers.nonsmooth import (proximal_arg_scaling, proximal_translation, |
19 | | - proximal_quadratic_perturbation, |
20 | | - proximal_const_func, proximal_convex_conj) |
21 | | -from odl.util import signature_string, indent |
22 | | - |
| 18 | + Operator, OperatorComp, OperatorLeftScalarMult, OperatorPointwiseProduct, |
| 19 | + OperatorRightScalarMult, OperatorRightVectorMult, OperatorSum) |
| 20 | +from odl.solvers.nonsmooth import ( |
| 21 | + proximal_arg_scaling, proximal_const_func, proximal_convex_conj, |
| 22 | + proximal_quadratic_perturbation, proximal_translation) |
| 23 | +from odl.util import indent, signature_string |
23 | 24 |
|
24 | 25 | __all__ = ('Functional', 'FunctionalLeftScalarMult', |
25 | 26 | 'FunctionalRightScalarMult', 'FunctionalComp', |
@@ -204,7 +205,7 @@ def derivative(self, point): |
204 | 205 | ------- |
205 | 206 | derivative : `Operator` |
206 | 207 | """ |
207 | | - return self.gradient(point).T |
| 208 | + return InnerProductOperator(self.domain, self.gradient(point)) |
208 | 209 |
|
209 | 210 | def translated(self, shift): |
210 | 211 | """Return a translation of the functional. |
@@ -1399,33 +1400,18 @@ def __init__(self, functional, point, subgrad): |
1399 | 1400 | raise TypeError('`functional` {} not an instance of ``Functional``' |
1400 | 1401 | ''.format(functional)) |
1401 | 1402 | self.__functional = functional |
1402 | | - |
1403 | | - if point not in functional.domain: |
1404 | | - raise ValueError('`point` {} is not in `functional.domain` {}' |
1405 | | - ''.format(point, functional.domain)) |
1406 | | - self.__point = point |
1407 | | - |
1408 | | - if subgrad not in functional.domain: |
1409 | | - raise TypeError( |
1410 | | - '`subgrad` must be an element in `functional.domain`, got ' |
1411 | | - '{}'.format(subgrad)) |
1412 | | - self.__subgrad = subgrad |
1413 | | - |
1414 | | - self.__constant = ( |
1415 | | - -functional(point) |
1416 | | - + functional.domain.inner(subgrad, point) |
1417 | | - ) |
1418 | | - |
| 1403 | + space = functional.domain |
| 1404 | + self.__point = space.element(point) |
| 1405 | + self.__subgrad = space.element(subgrad) |
| 1406 | + self.__constant = -functional(point) + space.inner(subgrad, point) |
1419 | 1407 | self.__bregman_dist = FunctionalQuadraticPerturb( |
1420 | | - functional, linear_term=-subgrad, constant=self.__constant) |
1421 | | - |
1422 | | - grad_lipschitz = ( |
1423 | | - functional.grad_lipschitz + functional.domain.norm(subgrad) |
| 1408 | + functional, linear_term=-subgrad, constant=self.__constant |
1424 | 1409 | ) |
| 1410 | + grad_lipschitz = functional.grad_lipschitz + space.norm(subgrad) |
1425 | 1411 |
|
1426 | 1412 | super(BregmanDistance, self).__init__( |
1427 | | - space=functional.domain, linear=False, |
1428 | | - grad_lipschitz=grad_lipschitz) |
| 1413 | + space, linear=False, grad_lipschitz=grad_lipschitz |
| 1414 | + ) |
1429 | 1415 |
|
1430 | 1416 | @property |
1431 | 1417 | def functional(self): |
@@ -1459,15 +1445,10 @@ def proximal(self): |
1459 | 1445 | @property |
1460 | 1446 | def gradient(self): |
1461 | 1447 | """Gradient operator of the functional.""" |
1462 | | - try: |
1463 | | - op_to_return = self.functional.gradient |
1464 | | - except NotImplementedError: |
1465 | | - raise NotImplementedError( |
1466 | | - '`self.functional.gradient` is not implemented for ' |
1467 | | - '`self.functional` {}'.format(self.functional)) |
1468 | | - |
1469 | | - op_to_return = op_to_return - ConstantOperator(self.subgrad) |
1470 | | - return op_to_return |
| 1448 | + return ( |
| 1449 | + self.functional.gradient |
| 1450 | + - ConstantOperator(self.domain, self.subgrad) |
| 1451 | + ) |
1471 | 1452 |
|
1472 | 1453 | def __repr__(self): |
1473 | 1454 | '''Return ``repr(self)``.''' |
|
0 commit comments