Skip to content

Commit

Permalink
Merge pull request devitocodes#1247 from devitocodes/ind-ref-diff
Browse files Browse the repository at this point in the history
FD for composite staggered expr
  • Loading branch information
mloubout authored May 6, 2020
2 parents e918e26 + 31ba38d commit 6966dfb
Show file tree
Hide file tree
Showing 11 changed files with 439 additions and 78 deletions.
48 changes: 45 additions & 3 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class Derivative(sympy.Derivative, Differentiable):
"""

_state = ('expr', 'dims', 'side', 'fd_order', 'transpose', '_subs', 'x0')
_fd_priority = 3

def __new__(cls, expr, *dims, **kwargs):
if type(expr) == sympy.Derivative:
Expand Down Expand Up @@ -197,8 +198,9 @@ def _new_from_self(self, **kwargs):
_kwargs = {'deriv_order': self.deriv_order, 'fd_order': self.fd_order,
'side': self.side, 'transpose': self.transpose, 'subs': self._subs,
'x0': self.x0, 'preprocessed': True}
expr = kwargs.pop('expr', self.expr)
_kwargs.update(**kwargs)
return Derivative(self.expr, *self.dims, **_kwargs)
return Derivative(expr, *self.dims, **_kwargs)

def subs(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -267,7 +269,24 @@ def _eval_at(self, func):
setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
has to be computed at x=x + h_x/2.
"""
x0 = dict(zip(func.dimensions, func.indices_ref))
x0 = dict(func.indices_ref._getters)
if self.expr.is_Add:
# Derivatives are linear and the derivative of an Add can be treated as an
# Add of derivatives which makes (u(x + h_x/2) + v(x)).dx` easier to handle
# since u(x + h_x/2) and v(x) require different indices
# for the finite difference.
args = [self._new_from_self(expr=a, x0=x0) if a in self.expr._args_diff else a
for a in self.expr.args]
return self.expr.func(*args)
elif self.expr.is_Mul:
# For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear
# in most equation with div(a * u) for example. The expression is re-centered
# at the highest priority index (see _gather_for_diff) to compute the
# derivative at x0.
return self._new_from_self(x0=x0, expr=self.expr._gather_for_diff)
# For every other cases, that has more functions or more complexe arithmetic,
# there is not actual way to decide what to do so it’s as safe to use
# the expression as is.
return self._new_from_self(x0=x0)

@property
Expand All @@ -277,8 +296,26 @@ def evaluate(self):
# types of discretizations.
return self._eval_fd(self.expr)

@property
def _eval_deriv(self):
return self._eval_fd(self.expr)

def _eval_fd(self, expr):
expr = getattr(expr, 'evaluate', expr)
"""
Evaluate finite difference approximation of the Derivative.
Evaluation is carried out via the following four steps:
- 1: Evaluate derivatives within the expression. For example given
`f.dx * g`, `f.dx` will be evaluated first.
- 2: Evaluate the finite difference for the (new) expression.
- 3: Evaluate remaining terms (as `g` may need to be evaluated
at a different point).
- 4: Apply substitutions.
"""
# Step 1: Evaluate derivatives within expression
expr = getattr(expr, '_eval_deriv', expr)

# Step 2: Evaluate FD of the new expression
if self.side is not None and self.deriv_order == 1:
res = first_derivative(expr, self.dims[0], self.fd_order,
side=self.side, matvec=self.transpose,
Expand All @@ -289,6 +326,11 @@ def _eval_fd(self, expr):
else:
res = generic_derivative(expr, *self.dims, self.fd_order, self.deriv_order,
matvec=self.transpose, x0=self.x0)

# Step 3: Evaluate remaining part of expression
res = res.evaluate

# Step 4: Apply substitution
for e in self._subs:
res = res.xreplace(e)
return res
68 changes: 66 additions & 2 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from cached_property import cached_property
from devito.finite_differences.lazy import Evaluable
from devito.logger import warning
from devito.tools import EnrichedTuple, filter_ordered, flatten
from devito.tools import filter_ordered, flatten
from devito.types.utils import DimensionTuple

__all__ = ['Differentiable']

Expand Down Expand Up @@ -86,7 +87,11 @@ def dimensions(self):
@property
def indices_ref(self):
"""The reference indices of the object (indices at first creation)."""
return EnrichedTuple(*self.dimensions, getters=self.dimensions)
if len(self._args_diff) == 1:
return self._args_diff[0].indices_ref
elif len(self._args_diff) == 0:
return DimensionTuple(*self.dimensions, getters=self.dimensions)
return highest_priority(self).indices_ref

@cached_property
def staggered(self):
Expand Down Expand Up @@ -115,6 +120,14 @@ def _eval_at(self, func):
return self
return self.func(*[getattr(a, '_eval_at', lambda x: a)(func) for a in self.args])

@property
def _eval_deriv(self):
return self.func(*[getattr(a, '_eval_deriv', a) for a in self.args])

@property
def _fd_priority(self):
return .75 if self.is_TimeDependent else .5

def __hash__(self):
return super(Differentiable, self).__hash__()

Expand Down Expand Up @@ -252,6 +265,11 @@ def _has(self, pattern):
return super(Differentiable, self)._has(pattern)


def highest_priority(DiffOp):
prio = lambda x: getattr(x, '_fd_priority', 0)
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]


class DifferentiableOp(Differentiable):

__sympy_class__ = None
Expand All @@ -266,6 +284,14 @@ def __new__(cls, *args, **kwargs):

return obj

def subs(self, *args, **kwargs):
return self.func(*[getattr(a, 'subs', lambda x: a)(*args, **kwargs)
for a in self.args], evaluate=False)

@property
def _gather_for_diff(self):
return self

# Bypass useless expensive SymPy _eval_ methods, for which we either already
# know or don't care about the answer, because it'd have ~zero impact on our
# average expressions
Expand Down Expand Up @@ -304,8 +330,46 @@ class Mul(DifferentiableOp, sympy.Mul):
__sympy_class__ = sympy.Mul
__new__ = DifferentiableOp.__new__

@property
def _gather_for_diff(self):
"""
We handle Mul arguments by hand in case of staggered inputs
such as `f(x)*g(x + h_x/2)` that will be transformed into
f(x + h_x/2)*g(x + h_x/2) and priority of indexing is applied
to have single indices as in this example.
The priority is from least to most:
- param
- NODE
- staggered
"""

if len(set(f.staggered for f in self._args_diff)) == 1:
return self

func_args = highest_priority(self)
new_args = []
ref_inds = func_args.indices_ref._getters

for f in self.args:
if f not in self._args_diff:
new_args.append(f)
elif f is func_args:
new_args.append(f)
else:
ind_f = f.indices_ref._getters
mapper = {ind_f.get(d, d): ref_inds.get(d, d)
for d in self.dimensions
if ind_f.get(d, d) is not ref_inds.get(d, d)}
if mapper:
new_args.append(f.subs(mapper))
else:
new_args.append(f)

return self.func(*new_args, evaluate=False)


class Pow(DifferentiableOp, sympy.Pow):
_fd_priority = 0
__sympy_class__ = sympy.Pow
__new__ = DifferentiableOp.__new__

Expand Down
6 changes: 3 additions & 3 deletions devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,7 @@ def generic_derivative(expr, dim, fd_order, deriv_order, symbolic=False,
def indices_weights_to_fd(expr, dim, inds, weights, matvec=1):
"""Expression from lists of indices and weights."""
diff = dim.spacing
all_dims = tuple(set((expr.indices_ref[dim],) + tuple(expr.indices_ref[dim]
for i in expr.dimensions if i.root is dim)))

d0 = ([d for d in expr.dimensions if d.root is dim] or [dim])[0]

mapper = {dim: d0, diff: matvec*diff}
Expand All @@ -275,8 +274,9 @@ def indices_weights_to_fd(expr, dim, inds, weights, matvec=1):
iloc = i.xreplace(mapper)
except AttributeError:
iloc = i
subs = dict((d, iloc) for d in all_dims)
subs = {expr.indices_ref[dim]: iloc}
terms.append(expr.subs(subs) * c)

deriv = Add(*terms)

return deriv.evalf(_PRECISION)
13 changes: 12 additions & 1 deletion devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.types.equation import Eq

__all__ = ['yreplace', 'xreplace_indices', 'pow_to_mul', 'as_symbol', 'indexify',
'split_affine', 'uxreplace']
'split_affine', 'uxreplace', 'aligned_indices']


def yreplace(exprs, make, rule=None, costmodel=lambda e: True, repeat=False, eager=False):
Expand Down Expand Up @@ -344,3 +344,14 @@ def indexify(expr):
except AttributeError:
pass
return expr.xreplace(mapper)


def aligned_indices(i, j, spacing):
"""
Check if two indices are aligned. Two indices are aligned if they
differ by an Integer*spacing.
"""
try:
return int((i - j)/spacing) == (i - j)/spacing
except TypeError:
return False
17 changes: 17 additions & 0 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from devito.data import default_allocator
from devito.finite_differences import Evaluable
from devito.parameters import configuration
from devito.symbolics import aligned_indices
from devito.tools import Pickable, ctypes_to_cstr, dtype_to_cstr, dtype_to_ctype
from devito.types.args import ArgProvider
from devito.types.caching import Cached
Expand Down Expand Up @@ -724,9 +725,25 @@ def dimensions(self):
"""Tuple of Dimensions representing the object indices."""
return self._dimensions

@property
def _eval_deriv(self):
return self

@cached_property
def _is_on_grid(self):
"""
Check whether the object is on the grid or need averaging.
For example, if the original non-staggered function is f(x)
then f(x) is on the grid and f(x + h_x/2) is off the grid.
"""
return all([aligned_indices(i, j, d.spacing) for i, j, d in
zip(self.indices, self.indices_ref, self.dimensions)])

@property
def evaluate(self):
# Average values if at a location not on the Function's grid
if self._is_on_grid:
return self
weight = 1.0
avg_list = [self]
is_averaged = False
Expand Down
1 change: 1 addition & 0 deletions devito/types/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init_cached__(self, key):
The cache key of the object whose state is used to initialize `self`.
It must be hashable.
"""

self.__dict__ = _SymbolCache[key]().__dict__

def __hash__(self):
Expand Down
17 changes: 14 additions & 3 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,16 +951,23 @@ def __init_finalize__(self, *args, **kwargs):
# parameter has to be computed at x + hx/2)
self._is_parameter = kwargs.get('parameter', False)

@cached_property
def _fd_priority(self):
return 1 if self.staggered in [NODE, None] else 2

@property
def is_parameter(self):
return self._is_parameter

def _eval_at(self, func):
if not self.is_parameter or self.staggered == func.staggered:
return self

return self.subs({self.indices_ref[d1]: func.indices_ref[d1]
for d1 in self.dimensions})
mapper = {self.indices_ref[d]: func.indices_ref[d]
for d in self.dimensions
if self.indices_ref[d] is not func.indices_ref[d]}
if mapper:
return self.subs(mapper)
return self

@classmethod
def __indices_setup__(cls, **kwargs):
Expand Down Expand Up @@ -1300,6 +1307,10 @@ def __shape_setup__(cls, **kwargs):
raise TypeError("`save` can be None, int or Buffer, not %s" % type(save))
return tuple(shape)

@cached_property
def _fd_priority(self):
return super(TimeFunction, self)._fd_priority + .1

@property
def time_order(self):
"""The time order."""
Expand Down
8 changes: 4 additions & 4 deletions examples/cfd/02_convection_nonlinear.ipynb

Large diffs are not rendered by default.

234 changes: 174 additions & 60 deletions examples/cfd/06_poisson.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def test_fd_adjoint(self, so, ndim, derivative, adjoint_name):
deriv = getattr(f, derivative)
coeff = 1 if derivative == 'dx2' else -1
expected = coeff * getattr(f, derivative).evaluate.subs({x.spacing: -x.spacing})
assert deriv.T.evaluate == expected
assert simplify(deriv.T.evaluate) == simplify(expected)

# Compute numerical derivatives and verify dot test
# i.e <f.dx, g> = <f, g.dx.T>
Expand Down
Loading

0 comments on commit 6966dfb

Please sign in to comment.