Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FD for composite staggered expr #1247

Merged
merged 7 commits into from
May 6, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
types: rework sub to keep track on/off grid
  • Loading branch information
mloubout committed May 6, 2020
commit 17a3d5b7fcdf87576a7a23b62caab947b4ad685b
24 changes: 12 additions & 12 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,19 +291,19 @@ def _eval_deriv(self):

def _eval_fd(self, expr):
"""
valuate finite difference approximation of the Derivative.
Evaluate finite difference approximation of the Derivative.
The evaluation goes in four steps:
- 1: evaluate derivatives within the expression. For example `f.dx * g` will
- 1: Evaluate derivatives within the expression. For example `f.dx * g` will
evaluate `f.dx` first.
- 2: Evaluate the finite difference for the (new) expression
- 3: Evaluate remaining (as `g` may need to be evaluated at a different point)
- 4: apply subsititutions.
- 2: Evaluate the finite difference for the (new) expression.
- 3: Evaluate remaining (as `g` may need to be evaluated at a different point).
- 4: Apply substitutions.

"""
# Evaluate derivatives in the expression
was_mul = expr.is_Mul
# Step 1: Evaluate derivatives within expression
expr = getattr(expr, '_eval_deriv', expr)
# Evaluate FD of the new expressiob

# Step 2: Evaluate FD of the new expression
if self.side is not None and self.deriv_order == 1:
res = first_derivative(expr, self.dims[0], self.fd_order,
side=self.side, matvec=self.transpose,
Expand All @@ -314,10 +314,10 @@ def _eval_fd(self, expr):
else:
res = generic_derivative(expr, *self.dims, self.fd_order, self.deriv_order,
matvec=self.transpose, x0=self.x0)
# Evaluate remaining part of expression
if was_mul:
res = res.evaluate
# Apply substitution

# Step 3: Evaluate remaining part of expression
res = res.evaluate
# Step 4: Apply substitution
for e in self._subs:
res = res.xreplace(e)
return res
13 changes: 9 additions & 4 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def indices_ref(self):
"""The reference indices of the object (indices at first creation)."""
if len(self._args_diff) == 1:
return self._args_diff[0].indices_ref
return EnrichedTuple(*self.indices, getters=self.dimensions)
return EnrichedTuple(*self.dimensions, getters=self.dimensions)

@cached_property
def staggered(self):
Expand Down Expand Up @@ -320,7 +320,7 @@ def _gather_for_diff(self):
- param
- NODE
- staggered
So for example f(x)*g(x + h_x/2) => .5*(f(x) + f(x + h_x))*g(x + h_x/2)
So for example f(x)*g(x + h_x/2) => f(x + h_x/2)*g(x + h_x/2)
"""

if len(set(f.staggered for f in self._args_diff)) == 1:
Expand All @@ -347,8 +347,13 @@ def stagg_prio(func):
new_args.append(f)
else:
ind_f = f.indices_ref._getters
new_args.append(f.subs({ind_f.get(d, d): ref_inds.get(d, d)
for d in self.dimensions}))
mapper = {ind_f.get(d, d): ref_inds.get(d, d)
for d in self.dimensions
if ind_f.get(d, d) is not ref_inds.get(d, d)}
if mapper:
new_args.append(f.subs(mapper, on_grid=False))
else:
new_args.append(f)

return self.func(*new_args, evaluate=False)

Expand Down
1 change: 1 addition & 0 deletions devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def indices_weights_to_fd(expr, dim, inds, weights, matvec=1):
iloc = i
subs = dict((d, iloc) for d in all_dims)
terms.append(expr.subs(subs) * c)

deriv = Add(*terms)

return deriv.evalf(_PRECISION)
21 changes: 20 additions & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ def __new__(cls, *args, **kwargs):

if obj is not None:
newobj = sympy.Function.__new__(cls, *args, **options)
newobj.__init_cached__(key)
newobj.__init_cached__(key, ignore=['_on_grid'])
return newobj

# Not in cache. Create a new Function via sympy.Function
Expand Down Expand Up @@ -664,6 +664,12 @@ def __init_finalize__(self, *args, **kwargs):

__hash__ = Cached.__hash__

def subs(self, *args, **kwargs):
on_grid = kwargs.pop('on_grid', True)
newobj = super(AbstractFunction, self).subs(*args, **kwargs)
newobj._on_grid = on_grid
return newobj

@classmethod
def __indices_setup__(cls, **kwargs):
"""Extract the object indices from ``kwargs``."""
Expand Down Expand Up @@ -724,9 +730,22 @@ def dimensions(self):
"""Tuple of Dimensions representing the object indices."""
return self._dimensions

@property
def _eval_deriv(self):
return self

@property
def on_grid(self):
try:
return self._on_grid
except AttributeError:
return True

@property
def evaluate(self):
# Average values if at a location not on the Function's grid
if self.on_grid:
return self
weight = 1.0
avg_list = [self]
is_averaged = False
Expand Down
6 changes: 4 additions & 2 deletions devito/types/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, key):
# Add ourselves to the symbol cache
_SymbolCache[key] = AugmentedWeakRef(self, self._cache_meta())

def __init_cached__(self, key):
def __init_cached__(self, key, ignore=None):
"""
Initialise `self` with a cached object state.

Expand All @@ -82,7 +82,9 @@ def __init_cached__(self, key):
The cache key of the object whose state is used to initialize `self`.
It must be hashable.
"""
self.__dict__ = _SymbolCache[key]().__dict__
ignore = ignore or []
self.__dict__ = {k: v for k, v in _SymbolCache[key]().__dict__.items()
if k not in ignore}

def __hash__(self):
"""
Expand Down
9 changes: 6 additions & 3 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,9 +958,12 @@ def is_parameter(self):
def _eval_at(self, func):
if not self.is_parameter or self.staggered == func.staggered:
return self

return self.subs({self.indices_ref[d1]: func.indices_ref[d1]
for d1 in self.dimensions})
mapper = {self.indices_ref[d1]: func.indices_ref[d1]
for d1 in self.dimensions
if self.indices_ref[d1] is not func.indices_ref[d1]}
if mapper:
return self.subs(mapper, on_grid=False)
return self

@classmethod
def __indices_setup__(cls, **kwargs):
Expand Down
8 changes: 4 additions & 4 deletions examples/cfd/02_convection_nonlinear.ipynb

Large diffs are not rendered by default.

234 changes: 174 additions & 60 deletions examples/cfd/06_poisson.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/test_staggered_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_avg(ndim):
# f at nod (x, y, z)
shifted = f
for dd in d:
shifted = shifted.subs({dd: dd - dd.spacing/2})
shifted = shifted.subs({dd: dd - dd.spacing/2}, on_grid=False)
assert all(i == dd for i, dd in zip(shifted.indices, grid.dimensions))
# Average automatically i.e.:
# f not defined at x so f(x, y) = 0.5*f(x - h_x/2, y) + 0.5*f(x + h_x/2, y)
Expand Down