Skip to content

Commit

Permalink
Merge pull request #310 from firedrakeproject/rckirby/feature/macro
Browse files Browse the repository at this point in the history
Support FIAT macroelements
  • Loading branch information
rckirby authored May 1, 2024
2 parents 90c20c5 + 916e773 commit 021589c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 37 deletions.
2 changes: 1 addition & 1 deletion tests/test_create_fiat_element.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

import FIAT
from FIAT.discontinuous_lagrange import HigherOrderDiscontinuousLagrange as FIAT_DiscontinuousLagrange
from FIAT.discontinuous_lagrange import DiscontinuousLagrange as FIAT_DiscontinuousLagrange

import ufl
import finat.ufl
Expand Down
29 changes: 14 additions & 15 deletions tsfc/finatinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@
# You should have received a copy of the GNU Lesser General Public License
# along with FFC. If not, see <http://www.gnu.org/licenses/>.

from functools import singledispatch, partial
import weakref
from functools import partial, singledispatch

import FIAT
import finat
import ufl
import finat.ufl

import ufl

__all__ = ("as_fiat_cell", "create_base_element",
"create_element", "supported_elements")
Expand All @@ -52,6 +51,8 @@
"Hermite": finat.Hermite,
"Kong-Mulder-Veldhuizen": finat.KongMulderVeldhuizen,
"Argyris": finat.Argyris,
"Hsieh-Clough-Tocher": finat.HsiehCloughTocher,
"Reduced-Hsieh-Clough-Tocher": finat.ReducedHsiehCloughTocher,
"Mardal-Tai-Winther": finat.MardalTaiWinther,
"Morley": finat.Morley,
"Bell": finat.Bell,
Expand Down Expand Up @@ -144,12 +145,10 @@ def convert_finiteelement(element, **kwargs):
kind = 'spectral' # default variant

if element.family() == "Lagrange":
if kind == 'equispaced':
lmbda = finat.Lagrange
elif kind == 'spectral':
if kind == 'spectral':
lmbda = finat.GaussLobattoLegendre
elif kind == 'integral':
lmbda = finat.IntegratedLegendre
elif kind.startswith('integral'):
lmbda = partial(finat.IntegratedLegendre, variant=kind)
elif kind in ['fdm', 'fdm_ipdg'] and is_interval:
lmbda = finat.FDMLagrange
elif kind == 'fdm_quadrature' and is_interval:
Expand All @@ -167,17 +166,16 @@ def convert_finiteelement(element, **kwargs):
deps = {"shift_axes", "restriction"}
return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction), deps
else:
raise ValueError("Variant %r not supported on %s" % (kind, element.cell))
# Let FIAT handle the general case
lmbda = partial(finat.Lagrange, variant=kind)
elif element.family() in {"Raviart-Thomas", "Nedelec 1st kind H(curl)",
"Brezzi-Douglas-Marini", "Nedelec 2nd kind H(curl)"}:
lmbda = partial(lmbda, variant=element.variant())
elif element.family() in ["Discontinuous Lagrange", "Discontinuous Lagrange L2"]:
if kind == 'equispaced':
lmbda = finat.DiscontinuousLagrange
elif kind == 'spectral':
if kind == 'spectral':
lmbda = finat.GaussLegendre
elif kind == 'integral':
lmbda = finat.Legendre
elif kind.startswith('integral'):
lmbda = partial(finat.Legendre, variant=kind)
elif kind in ['fdm', 'fdm_quadrature'] and is_interval:
lmbda = finat.FDMDiscontinuousLagrange
elif kind == 'fdm_ipdg' and is_interval:
Expand All @@ -191,7 +189,8 @@ def convert_finiteelement(element, **kwargs):
deps = {"shift_axes", "restriction"}
return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction, continuous=False), deps
else:
raise ValueError("Variant %r not supported on %s" % (kind, element.cell))
# Let FIAT handle the general case
lmbda = partial(finat.DiscontinuousLagrange, variant=kind)
elif element.family() == ["DPC", "DPC L2"]:
if element.cell.geometric_dimension() == 2:
element = element.reconstruct(cell=ufl.cell.hypercube(2))
Expand Down
42 changes: 21 additions & 21 deletions tsfc/kernel_interface/common.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
import collections
import string
import operator
import string
from functools import reduce
from itertools import chain, product

import gem
import gem.impero_utils as impero_utils
import numpy
from numpy import asarray

from ufl.utils.sequences import max_degree

from FIAT.reference_element import TensorProductCell

from finat.cell_tools import max_complex
from finat.quadrature import AbstractQuadratureRule, make_quadrature

import gem

from gem.node import traversal
from gem.optimise import constant_fold_zero
from gem.optimise import remove_componenttensors as prune
from gem.utils import cached_property
import gem.impero_utils as impero_utils
from gem.optimise import remove_componenttensors as prune, constant_fold_zero

from numpy import asarray
from tsfc import fem, ufl_utils
from tsfc.kernel_interface import KernelInterface
from tsfc.finatinterface import as_fiat_cell, create_element
from tsfc.kernel_interface import KernelInterface
from tsfc.logging import logger
from ufl.utils.sequences import max_degree


class KernelBuilderBase(KernelInterface):
Expand Down Expand Up @@ -301,22 +297,26 @@ def set_quad_rule(params, cell, integral_type, functions):
quadrature_degree = params["quadrature_degree"]
except KeyError:
quadrature_degree = params["estimated_polynomial_degree"]
function_degrees = [f.ufl_function_space().ufl_element().degree() for f in functions]
function_degrees = [f.ufl_function_space().ufl_element().degree()
for f in functions]
if all((asarray(quadrature_degree) > 10 * asarray(degree)).all()
for degree in function_degrees):
logger.warning("Estimated quadrature degree %s more "
"than tenfold greater than any "
"argument/coefficient degree (max %s)",
quadrature_degree, max_degree(function_degrees))
if params.get("quadrature_rule") == "default":
del params["quadrature_rule"]
try:
quad_rule = params["quadrature_rule"]
except KeyError:
quad_rule = params.get("quadrature_rule", "default")
if isinstance(quad_rule, str):
scheme = quad_rule
fiat_cell = as_fiat_cell(cell)
finat_elements = set(create_element(f.ufl_element()) for f in functions
if f.ufl_element().family() != "Real")
fiat_cells = [fiat_cell] + [finat_el.complex for finat_el in finat_elements]
fiat_cell = max_complex(fiat_cells)

integration_dim, _ = lower_integral_type(fiat_cell, integral_type)
integration_cell = fiat_cell.construct_subelement(integration_dim)
quad_rule = make_quadrature(integration_cell, quadrature_degree)
integration_cell = fiat_cell.construct_subcomplex(integration_dim)
quad_rule = make_quadrature(integration_cell, quadrature_degree, scheme=scheme)
params["quadrature_rule"] = quad_rule

if not isinstance(quad_rule, AbstractQuadratureRule):
Expand Down

0 comments on commit 021589c

Please sign in to comment.