Skip to content

Commit

Permalink
Slate optimiser: Introduction of a Slate optimisation which does Tran…
Browse files Browse the repository at this point in the history
…spose(Tensor(form) -> Tensor(adjoint(form)). This optimisation is introduced at the Slate optimiser level but essentially results in inlining transposes on tensors into the local assembly kernels generated by TSFC
  • Loading branch information
sv2518 committed Nov 10, 2021
1 parent 3033537 commit 3dc3e7d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
6 changes: 6 additions & 0 deletions firedrake/slate/slate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections import OrderedDict

from ufl import Coefficient, Constant
from firedrake.ufl_expr import adjoint

from firedrake.function import Function
from firedrake.utils import cached_property
Expand Down Expand Up @@ -1024,6 +1025,11 @@ def _output_string(self, prec=None):
class Transpose(UnaryOp):
"""An abstract Slate class representing the transpose of a tensor."""

def __new__(cls, A):
if A.terminal and A.rank > 1:
return Tensor(adjoint(A.form))
return super().__new__(cls)

@cached_property
def arg_function_spaces(self):
"""Returns a tuple of function spaces that the tensor
Expand Down
35 changes: 35 additions & 0 deletions tests/slate/test_linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,38 @@ def test_local_solve(decomp):
x = assemble(A.solve(b, decomposition=decomp))

assert np.allclose(x.dat.data, f.dat.data, rtol=1.e-13)


def test_transpose():
mesh = UnitSquareMesh(2, 2, True)
p1 = VectorElement("CG", triangle, 2)
p0 = FiniteElement("CG", triangle, 1)
p1p0 = MixedElement([p1, p0])

velo = FunctionSpace(mesh, p1)
pres = FunctionSpace(mesh, p0)
mixed = FunctionSpace(mesh, p1p0)

w = Function(mixed)
x = SpatialCoordinate(mesh)
velo = Function(velo).project(as_vector([10*sin(pi*x[0]), 0]))
w.sub(0).assign(velo)
pres = Function(pres).assign(10.)
w.sub(1).assign(pres)

dg = FunctionSpace(mesh, "DG", 2)
T = TrialFunction(dg)
v = TestFunction(dg)

n = FacetNormal(mesh)
u = split(w)[0]
un = abs(dot(u('+'), n('+')))
jump_v = v('+')*n('+') + v('-')*n('-')
jump_T = T('+')*n('+') + T('-')*n('-')
x, y = SpatialCoordinate(mesh)

T = Tensor(-dot(u*T, grad(v))*dx + (dot(u('+'), jump_v)*avg(T))*dS + dot(v, dot(u, n)*T)*ds + 0.5*un*dot(jump_T, jump_v)*dS)
assert T != T.T
assert isinstance(T.T, Tensor)
assert np.allclose(assemble(T.T).M.values,
assemble(adjoint(T.form)).M.values)

0 comments on commit 3dc3e7d

Please sign in to comment.