Skip to content

Commit

Permalink
Refactor variants
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Oct 24, 2024
1 parent 283cc54 commit d7f064c
Showing 1 changed file with 39 additions and 36 deletions.
75 changes: 39 additions & 36 deletions tsfc/finatinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# along with FFC. If not, see <http://www.gnu.org/licenses/>.

import weakref
from functools import partial, singledispatch
from functools import singledispatch

import FIAT
import finat
Expand Down Expand Up @@ -124,6 +124,23 @@ def convert(element, **kwargs):
raise ValueError("Unsupported element type %s" % type(element))


cg_interval_variants = {
"fdm": finat.FDMLagrange,
"fdm_ipdg": finat.FDMLagrange,
"fdm_quadrature": finat.FDMQuadrature,
"fdm_broken": finat.FDMBrokenH1,
"fdm_hermite": finat.FDMHermite,
}


dg_interval_variants = {
"fdm": finat.FDMDiscontinuousLagrange,
"fdm_quadrature": finat.FDMDiscontinuousLagrange,
"fdm_ipdg": lambda *args: finat.DiscontinuousElement(finat.FDMLagrange(*args)),
"fdm_broken": finat.FDMBrokenL2,
}


# Base finite elements first
@convert.register(finat.ufl.FiniteElement)
def convert_finiteelement(element, **kwargs):
Expand Down Expand Up @@ -152,30 +169,19 @@ def convert_finiteelement(element, **kwargs):
finat_elem, deps = _create_element(element, **kwargs)
return finat.FlattenedDimensions(finat_elem), deps

kw = {}
kind = element.variant()
if kind is None:
kind = 'spectral' # default variant
is_interval = element.cell.cellname() == 'interval'

if element.family() in {"Raviart-Thomas", "Nedelec 1st kind H(curl)",
"Brezzi-Douglas-Marini", "Nedelec 2nd kind H(curl)",
"Argyris"}:
lmbda = partial(lmbda, variant=element.variant())
elif element.family() == "Lagrange":
if element.family() == "Lagrange":
if kind == 'spectral':
lmbda = finat.GaussLobattoLegendre
elif kind.startswith('integral'):
lmbda = partial(finat.IntegratedLegendre, variant=kind)
elif kind in ['fdm', 'fdm_ipdg'] and is_interval:
lmbda = finat.FDMLagrange
elif kind == 'fdm_quadrature' and is_interval:
lmbda = finat.FDMQuadrature
elif kind == 'fdm_broken' and is_interval:
lmbda = finat.FDMBrokenH1
elif kind == 'fdm_hermite' and is_interval:
lmbda = finat.FDMHermite
elif kind in ['demkowicz', 'fdm']:
lmbda = partial(finat.IntegratedLegendre, variant=kind)
elif element.cell.cellname() == "interval" and kind in cg_interval_variants:
lmbda = cg_interval_variants[kind]
elif kind.startswith('integral') or kind in ['demkowicz', 'fdm']:
lmbda = finat.IntegratedLegendre
kw["variant"] = kind
elif kind in ['mgd', 'feec', 'qb', 'mse']:
degree = element.degree()
shift_axes = kwargs["shift_axes"]
Expand All @@ -184,20 +190,17 @@ def convert_finiteelement(element, **kwargs):
return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction), deps
else:
# Let FIAT handle the general case
lmbda = partial(finat.Lagrange, variant=kind)
lmbda = finat.Lagrange
kw["variant"] = kind

elif element.family() in ["Discontinuous Lagrange", "Discontinuous Lagrange L2"]:
if kind == 'spectral':
lmbda = finat.GaussLegendre
elif kind.startswith('integral'):
lmbda = partial(finat.Legendre, variant=kind)
elif kind in ['fdm', 'fdm_quadrature'] and is_interval:
lmbda = finat.FDMDiscontinuousLagrange
elif kind == 'fdm_ipdg' and is_interval:
lmbda = lambda *args: finat.DiscontinuousElement(finat.FDMLagrange(*args))
elif kind in 'fdm_broken' and is_interval:
lmbda = finat.FDMBrokenL2
elif kind in ['demkowicz', 'fdm']:
lmbda = partial(finat.Legendre, variant=kind)
elif element.cell.cellname() == "interval" and kind in dg_interval_variants:
lmbda = dg_interval_variants[kind]
elif kind.startswith('integral') or kind in ['demkowicz', 'fdm']:
lmbda = finat.Legendre
kw["variant"] = kind
elif kind in ['mgd', 'feec', 'qb', 'mse']:
degree = element.degree()
shift_axes = kwargs["shift_axes"]
Expand All @@ -206,13 +209,13 @@ def convert_finiteelement(element, **kwargs):
return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction, continuous=False), deps
else:
# Let FIAT handle the general case
lmbda = partial(finat.DiscontinuousLagrange, variant=kind)
elif element.family() == ["DPC", "DPC L2", "S"]:
dim = element.cell.geometric_dimension()
if dim > 1:
element = element.reconstruct(cell=ufl.cell.hypercube(dim))
lmbda = finat.DiscontinuousLagrange
kw["variant"] = kind

elif element.variant() is not None:
kw["variant"] = element.variant()

return lmbda(cell, element.degree()), set()
return lmbda(cell, element.degree(), **kw), set()


# Element modifiers and compound element types
Expand Down

0 comments on commit d7f064c

Please sign in to comment.