Skip to content

Commit

Permalink
Move the optimisation a level lower
Browse files Browse the repository at this point in the history
  • Loading branch information
sv2518 committed Nov 9, 2021
1 parent 4adcf93 commit 7822a43
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 41 deletions.
3 changes: 3 additions & 0 deletions firedrake/slate/slac/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from itertools import repeat
from firedrake.slate.slate import *
from collections import namedtuple
from firedrake.ufl_expr import adjoint

""" ActionBag class
:arg coeff: what we contract with.
Expand Down Expand Up @@ -235,6 +236,8 @@ def _drop_double_transpose_transpose(expr, self):
if isinstance(child, Transpose):
grandchild, = child.children
return self(grandchild)
elif child.terminal and child.rank > 1:
return Tensor(adjoint(child.form))
else:
return type(expr)(*map(self, expr.children))

Expand Down
6 changes: 0 additions & 6 deletions firedrake/slate/slate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from collections import OrderedDict

from ufl import Coefficient, Constant
from firedrake.ufl_expr import adjoint

from firedrake.function import Function
from firedrake.utils import cached_property
Expand Down Expand Up @@ -1025,11 +1024,6 @@ def _output_string(self, prec=None):
class Transpose(UnaryOp):
"""An abstract Slate class representing the transpose of a tensor."""

def __new__(cls, A):
if A.terminal and A.rank > 1:
return Tensor(adjoint(A.form))
return super().__new__(cls)

@cached_property
def arg_function_spaces(self):
"""Returns a tuple of function spaces that the tensor
Expand Down
35 changes: 0 additions & 35 deletions tests/slate/test_linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,38 +116,3 @@ def test_local_solve(decomp):
x = assemble(A.solve(b, decomposition=decomp))

assert np.allclose(x.dat.data, f.dat.data, rtol=1.e-13)


def test_transpose():
mesh = UnitSquareMesh(2, 2, True)
p1 = VectorElement("CG", triangle, 2)
p0 = FiniteElement("CG", triangle, 1)
p1p0 = MixedElement([p1, p0])

velo = FunctionSpace(mesh, p1)
pres = FunctionSpace(mesh, p0)
mixed = FunctionSpace(mesh, p1p0)

w = Function(mixed)
x = SpatialCoordinate(mesh)
velo = Function(velo).project(as_vector([10*sin(pi*x[0]), 0]))
w.sub(0).assign(velo)
pres = Function(pres).assign(10.)
w.sub(1).assign(pres)

dg = FunctionSpace(mesh, "DG", 2)
T = TrialFunction(dg)
v = TestFunction(dg)

n = FacetNormal(mesh)
u = split(w)[0]
un = abs(dot(u('+'), n('+')))
jump_v = v('+')*n('+') + v('-')*n('-')
jump_T = T('+')*n('+') + T('-')*n('-')
x, y = SpatialCoordinate(mesh)

T = Tensor(-dot(u*T, grad(v))*dx + (dot(u('+'), jump_v)*avg(T))*dS + dot(v, dot(u, n)*T)*ds + 0.5*un*dot(jump_T, jump_v)*dS)
assert T != T.T
assert isinstance(T.T, Tensor)
assert np.allclose(assemble(T.T).M.values,
assemble(adjoint(T.form)).M.values)
7 changes: 7 additions & 0 deletions tests/slate/test_optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,13 @@ def test_drop_transposes(TC_non_symm):
compare_vector_expressions(expressions)
compare_slate_tensors(expressions, opt_expressions)

assert A != A.T
from firedrake.slate.slac.optimise import optimise
T_opt = optimise(A.T, {"optimise": True})
assert isinstance(T_opt, Tensor)
assert np.allclose(assemble(T_opt.T).M.values,
assemble(adjoint(T_opt.form)).M.values)


#######################################
# Test diagonal optimisation pass
Expand Down

0 comments on commit 7822a43

Please sign in to comment.