Skip to content

Commit

Permalink
move replace_to_be_restricted to ufl
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Apr 9, 2024
1 parent ef4754e commit 08a7b80
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 47 deletions.
4 changes: 1 addition & 3 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,10 @@ def compile_integral(integral_types, integrals, integral_data, form_data, domain
# so we should attach the constants to integral data instead
builder.set_constants(form_data.constants)
ctx = builder.create_context()
# TODO: Move relevant part of the code to UFL and remove this flag.
have_multiple_domains = len(all_meshes) > 1
for integral in integrals:
params = parameters.copy()
params.update(integral.metadata()) # integral metadata overrides
integrand_exprs = builder.compile_integrand(integral.integrand(), params, ctx, have_multiple_domains)
integrand_exprs = builder.compile_integrand(integral.integrand(), params, ctx)
integral_exprs = builder.construct_integrals(integrand_exprs, params)
builder.stash_integrals(integral_exprs, params, ctx)
return builder.construct_kernel(kernel_name, ctx, log)
Expand Down
6 changes: 1 addition & 5 deletions tsfc/kernel_interface/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def register_requirements(self, ir):
class KernelBuilderMixin(object):
"""Mixin for KernelBuilder classes."""

def compile_integrand(self, integrand, params, ctx, have_multiple_domains):
def compile_integrand(self, integrand, params, ctx):
"""Compile UFL integrand.
:arg integrand: UFL integrand.
Expand All @@ -135,10 +135,6 @@ def compile_integrand(self, integrand, params, ctx, have_multiple_domains):
See :meth:`create_context` for typical calling sequence.
"""
# Remove '?' restrictions
if have_multiple_domains:
# TODO: This should happen in UFL.
integrand = ufl_utils.replace_to_be_restricted_restrictions(integrand, self._domain_integral_type_map)
# Compile: ufl -> gem
info = self.integral_data_info
functions = list(info.arguments) + [self.coordinate(info.domain)] + list(info.coefficients)
Expand Down
39 changes: 0 additions & 39 deletions tsfc/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,42 +590,3 @@ def remove_indices(o):
else:
rule = IndexRemover()
return map_expr_dag(rule, o)


# Move this to UFL.
class ToBeRestrectedReplacer(MultiFunction, ModifiedTerminalMixin):

def __init__(self, domain_integral_type_map):
MultiFunction.__init__(self)
self.domain_integral_type_map = domain_integral_type_map

expr = MultiFunction.reuse_if_untouched

def modified_terminal(self, o):
mt = analyse_modified_terminal(o)
t = mt.terminal
r = mt.restriction
rv = mt.reference_value
ld = mt.local_derivatives
if r != '?':
return o
domain = extract_unique_domain(t)
if domain not in self.domain_integral_type_map:
raise RuntimeError(f"Integral type on {domain} not known")
integral_type = self.domain_integral_type_map[domain]
if integral_type == "cell":
mmt = ModifiedTerminal(o, t, ld, None, rv)
elif integral_type == "exterior_facet":
mmt = ModifiedTerminal(o, t, ld, '|', rv)
elif integral_type == "interial_facet":
mmt = ModifiedTerminal(o, t, ld, '+', rv)
else:
raise RuntimeError(f"Integral type {integral_type} not handled")
return construct_modified_terminal(mmt, t)


# Move this to UFL.
def replace_to_be_restricted_restrictions(integrand, domain_integral_type_map):
integrand = remove_indices(integrand)
rule = ToBeRestrectedReplacer(domain_integral_type_map)
return map_expr_dag(rule, integrand)

0 comments on commit 08a7b80

Please sign in to comment.