Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions ufl/algorithms/apply_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2292,3 +2292,90 @@ def apply_coordinate_derivatives(expression):
"""Apply coordinate derivatives to an expression."""
dag_traverser = CoordinateDerivativeRuleDispatcher()
return map_integrands(dag_traverser, expression)


class CoefficientDerivativeRuleDispatcher(DAGTraverser):
"""Dispatcher."""

def __init__(
self,
compress: bool | None = True,
visited_cache: dict[tuple, Expr] | None = None,
result_cache: dict[Expr, Expr] | None = None,
) -> None:
"""Initialise."""
super().__init__(compress=compress, visited_cache=visited_cache, result_cache=result_cache)
# Record the operations delayed to the derivative expansion phase:
# Example: dN(u)/du where `N` is a BaseFormOperator and `u` a Coefficient
self.pending_operations = ()
# Create DAGTraverser caches.
self._dag_traverser_cache: dict[
tuple[type, Expr] | tuple[type, Expr, Expr, Expr] | tuple[type, Expr, Expr, Expr, Expr],
DAGTraverser,
] = {}

@singledispatchmethod
def process(self, o: Expr) -> Expr:
"""Process ``o``.

Args:
o: `Expr` to be processed.

Returns:
Processed object.

"""
return super().process(o)

@process.register(Expr)
@process.register(BaseForm) # type: ignore
def _(self, o: Expr | BaseForm) -> Expr | BaseForm:
"""Apply to expr and base form."""
return self.reuse_if_untouched(o)

@process.register(Terminal)
def _(self, o: Terminal) -> Terminal:
"""Apply to a terminal."""
return o

@process.register(CoefficientDerivative)
@DAGTraverser.postorder_only_children([0])
def _(self, o: CoefficientDerivative, f: Expr | BaseForm) -> Expr | BaseForm:
"""Apply to a coefficient_derivative."""
_, w, v, cd = o.ufl_operands
key = (GateauxDerivativeRuleset, w, v, cd)
# We need to go through the dag first to record the pending
# operations
dag_traverser = self._dag_traverser_cache.setdefault(
key,
GateauxDerivativeRuleset(w, v, cd), # type: ignore
)
# If f has been seen by the traverser, it immediately returns
# the cached value.
mapped_expr = dag_traverser(f) # type: ignore
# Need to account for pending operations that have been stored
# in other integrands
self.pending_operations += dag_traverser.pending_operations # type: ignore
return mapped_expr

@process.register(Indexed)
@DAGTraverser.postorder
def _(self, o: Indexed, Ap: Expr, ii: MultiIndex) -> Expr | BaseForm:
"""Apply to an indexed."""
# Reuse if untouched
if Ap is o.ufl_operands[0]:
return o
r = len(Ap.ufl_shape) - len(ii)
if r:
kk = indices(r)
op = Indexed(Ap, MultiIndex(ii.indices() + kk))
op = as_tensor(op, kk)
else:
op = Indexed(Ap, ii)
return op


def apply_coefficient_derivatives(expression):
"""Apply coefficient derivatives to an expression."""
dag_traverser = CoefficientDerivativeRuleDispatcher()
return map_integrands(dag_traverser, expression)
6 changes: 4 additions & 2 deletions ufl/algorithms/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ def replace(e, mapping):
# is not attractive), or make replace lazy too.
if has_exact_type(e, CoefficientDerivative):
# Hack to avoid circular dependencies
from ufl.algorithms.ad import expand_derivatives
from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering
from ufl.algorithms.apply_derivatives import apply_coefficient_derivatives

e = expand_derivatives(e)
e = apply_algebra_lowering(e)
e = apply_coefficient_derivatives(e)

return map_integrand_dags(Replacer(mapping2), e)