diff --git a/tests/test_create_fiat_element.py b/tests/test_create_fiat_element.py index 28356b00..f8a7d6ef 100644 --- a/tests/test_create_fiat_element.py +++ b/tests/test_create_fiat_element.py @@ -1,7 +1,7 @@ import pytest import FIAT -from FIAT.discontinuous_lagrange import HigherOrderDiscontinuousLagrange as FIAT_DiscontinuousLagrange +from FIAT.discontinuous_lagrange import DiscontinuousLagrange as FIAT_DiscontinuousLagrange import ufl import finat.ufl diff --git a/tsfc/finatinterface.py b/tsfc/finatinterface.py index b0c813ed..320b3e32 100644 --- a/tsfc/finatinterface.py +++ b/tsfc/finatinterface.py @@ -19,14 +19,13 @@ # You should have received a copy of the GNU Lesser General Public License # along with FFC. If not, see . -from functools import singledispatch, partial import weakref +from functools import partial, singledispatch import FIAT import finat -import ufl import finat.ufl - +import ufl __all__ = ("as_fiat_cell", "create_base_element", "create_element", "supported_elements") @@ -52,6 +51,8 @@ "Hermite": finat.Hermite, "Kong-Mulder-Veldhuizen": finat.KongMulderVeldhuizen, "Argyris": finat.Argyris, + "Hsieh-Clough-Tocher": finat.HsiehCloughTocher, + "Reduced-Hsieh-Clough-Tocher": finat.ReducedHsiehCloughTocher, "Mardal-Tai-Winther": finat.MardalTaiWinther, "Morley": finat.Morley, "Bell": finat.Bell, @@ -144,12 +145,10 @@ def convert_finiteelement(element, **kwargs): kind = 'spectral' # default variant if element.family() == "Lagrange": - if kind == 'equispaced': - lmbda = finat.Lagrange - elif kind == 'spectral': + if kind == 'spectral': lmbda = finat.GaussLobattoLegendre - elif kind == 'integral': - lmbda = finat.IntegratedLegendre + elif kind.startswith('integral'): + lmbda = partial(finat.IntegratedLegendre, variant=kind) elif kind in ['fdm', 'fdm_ipdg'] and is_interval: lmbda = finat.FDMLagrange elif kind == 'fdm_quadrature' and is_interval: @@ -167,17 +166,16 @@ def convert_finiteelement(element, **kwargs): deps = {"shift_axes", "restriction"} return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction), deps else: - raise ValueError("Variant %r not supported on %s" % (kind, element.cell)) + # Let FIAT handle the general case + lmbda = partial(finat.Lagrange, variant=kind) elif element.family() in {"Raviart-Thomas", "Nedelec 1st kind H(curl)", "Brezzi-Douglas-Marini", "Nedelec 2nd kind H(curl)"}: lmbda = partial(lmbda, variant=element.variant()) elif element.family() in ["Discontinuous Lagrange", "Discontinuous Lagrange L2"]: - if kind == 'equispaced': - lmbda = finat.DiscontinuousLagrange - elif kind == 'spectral': + if kind == 'spectral': lmbda = finat.GaussLegendre - elif kind == 'integral': - lmbda = finat.Legendre + elif kind.startswith('integral'): + lmbda = partial(finat.Legendre, variant=kind) elif kind in ['fdm', 'fdm_quadrature'] and is_interval: lmbda = finat.FDMDiscontinuousLagrange elif kind == 'fdm_ipdg' and is_interval: @@ -191,7 +189,8 @@ def convert_finiteelement(element, **kwargs): deps = {"shift_axes", "restriction"} return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction, continuous=False), deps else: - raise ValueError("Variant %r not supported on %s" % (kind, element.cell)) + # Let FIAT handle the general case + lmbda = partial(finat.DiscontinuousLagrange, variant=kind) elif element.family() == ["DPC", "DPC L2"]: if element.cell.geometric_dimension() == 2: element = element.reconstruct(cell=ufl.cell.hypercube(2)) diff --git a/tsfc/kernel_interface/common.py b/tsfc/kernel_interface/common.py index 67f7dac8..18fd363f 100644 --- a/tsfc/kernel_interface/common.py +++ b/tsfc/kernel_interface/common.py @@ -1,29 +1,25 @@ import collections -import string import operator +import string from functools import reduce from itertools import chain, product +import gem +import gem.impero_utils as impero_utils import numpy -from numpy import asarray - -from ufl.utils.sequences import max_degree - from FIAT.reference_element import TensorProductCell - +from finat.cell_tools import max_complex from finat.quadrature import AbstractQuadratureRule, make_quadrature - -import gem - from gem.node import traversal +from gem.optimise import constant_fold_zero +from gem.optimise import remove_componenttensors as prune from gem.utils import cached_property -import gem.impero_utils as impero_utils -from gem.optimise import remove_componenttensors as prune, constant_fold_zero - +from numpy import asarray from tsfc import fem, ufl_utils -from tsfc.kernel_interface import KernelInterface from tsfc.finatinterface import as_fiat_cell, create_element +from tsfc.kernel_interface import KernelInterface from tsfc.logging import logger +from ufl.utils.sequences import max_degree class KernelBuilderBase(KernelInterface): @@ -301,22 +297,26 @@ def set_quad_rule(params, cell, integral_type, functions): quadrature_degree = params["quadrature_degree"] except KeyError: quadrature_degree = params["estimated_polynomial_degree"] - function_degrees = [f.ufl_function_space().ufl_element().degree() for f in functions] + function_degrees = [f.ufl_function_space().ufl_element().degree() + for f in functions] if all((asarray(quadrature_degree) > 10 * asarray(degree)).all() for degree in function_degrees): logger.warning("Estimated quadrature degree %s more " "than tenfold greater than any " "argument/coefficient degree (max %s)", quadrature_degree, max_degree(function_degrees)) - if params.get("quadrature_rule") == "default": - del params["quadrature_rule"] - try: - quad_rule = params["quadrature_rule"] - except KeyError: + quad_rule = params.get("quadrature_rule", "default") + if isinstance(quad_rule, str): + scheme = quad_rule fiat_cell = as_fiat_cell(cell) + finat_elements = set(create_element(f.ufl_element()) for f in functions + if f.ufl_element().family() != "Real") + fiat_cells = [fiat_cell] + [finat_el.complex for finat_el in finat_elements] + fiat_cell = max_complex(fiat_cells) + integration_dim, _ = lower_integral_type(fiat_cell, integral_type) - integration_cell = fiat_cell.construct_subelement(integration_dim) - quad_rule = make_quadrature(integration_cell, quadrature_degree) + integration_cell = fiat_cell.construct_subcomplex(integration_dim) + quad_rule = make_quadrature(integration_cell, quadrature_degree, scheme=scheme) params["quadrature_rule"] = quad_rule if not isinstance(quad_rule, AbstractQuadratureRule):