Skip to content

Commit

Permalink
fd: add preprocessign of args for FD on Mul
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Apr 23, 2020
1 parent 7109ddb commit 1928564
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 4 deletions.
16 changes: 13 additions & 3 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,9 @@ def _new_from_self(self, **kwargs):
_kwargs = {'deriv_order': self.deriv_order, 'fd_order': self.fd_order,
'side': self.side, 'transpose': self.transpose, 'subs': self._subs,
'x0': self.x0, 'preprocessed': True}
expr = kwargs.pop('expr', self.expr)
_kwargs.update(**kwargs)
return Derivative(self.expr, *self.dims, **_kwargs)
return Derivative(expr, *self.dims, **_kwargs)

def subs(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -267,7 +268,14 @@ def _eval_at(self, func):
setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
has to be computed at x=x + h_x/2.
"""
x0 = dict(zip(func.dimensions, func.indices_ref))
# Split the Add case to handle different staggereing
x0 = dict(func.indices_ref._getters)
if self.expr.is_Add:
args = [self._new_from_self(expr=a, x0=x0) for a in self.expr._args_diff]
args += [a for a in self.expr.args if a not in self.expr._args_diff]
return self.expr.func(*args)
elif self.expr.is_Mul:
return self._new_from_self(x0=x0, expr=self.expr._gather_for_diff)
return self._new_from_self(x0=x0)

@property
Expand All @@ -278,7 +286,7 @@ def evaluate(self):
return self._eval_fd(self.expr)

def _eval_fd(self, expr):
expr = getattr(expr, 'evaluate', expr)
expr = expr if expr.is_Mul else getattr(expr, 'evaluate', expr)
if self.side is not None and self.deriv_order == 1:
res = first_derivative(expr, self.dims[0], self.fd_order,
side=self.side, matvec=self.transpose,
Expand All @@ -289,6 +297,8 @@ def _eval_fd(self, expr):
else:
res = generic_derivative(expr, *self.dims, self.fd_order, self.deriv_order,
matvec=self.transpose, x0=self.x0)
if expr.is_Mul:
res = res.evaluate
for e in self._subs:
res = res.xreplace(e)
return res
41 changes: 40 additions & 1 deletion devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def indices_ref(self):
"""The reference indices of the object (indices at first creation)."""
if len(self._args_diff) == 1:
return self._args_diff[0].indices_ref
return EnrichedTuple(*self.dimensions, getters=self.dimensions)
return EnrichedTuple(*self.indices, getters=self.dimensions)

@cached_property
def staggered(self):
Expand Down Expand Up @@ -306,6 +306,45 @@ class Mul(DifferentiableOp, sympy.Mul):
__sympy_class__ = sympy.Mul
__new__ = DifferentiableOp.__new__

@property
def _gather_for_diff(self):
"""
We handle Mul arguments by hand in case of staggered inputs
such as `f(x)*g(x + h_x/2)`
In that case, priority of indexing is applied to have single indices
The priority is from least to most:
- param
- NODE
- staggered
So for example f(x)*g(x + h_x/2) => .5*(f(x) + f(x + h_x))*g(x + h_x/2)
"""
if len(set(f.staggered for f in self._args_diff)) == 1:
return self

def stagg_prio(func):
if func._is_parameter or func.staggered is None:
return 0
elif func.staggered in self.dimensions:
return 2
else:
return 1

func_args = sorted(self._args_diff, key=stagg_prio, reverse=True)
new_args = []
ref_inds = func_args[0].indices_ref._getters

for f in self.args:
if f not in self._args_diff:
new_args += [f]
elif f is func_args[0]:
new_args += [f]
else:
ind_f = f.indices_ref._getters
new_args += [f.subs({ind_f.get(d, d): ref_inds.get(d, d)
for d in self.dimensions})]

return self.func(*new_args, evaluate=False)


class Pow(DifferentiableOp, sympy.Pow):
__sympy_class__ = sympy.Pow
Expand Down

0 comments on commit 1928564

Please sign in to comment.