Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions test/test_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
from utils import LagrangeElement

from ufl import (
Adjoint,
Argument,
Coefficient,
Cofunction,
Form,
FormProduct,
FormSum,
FunctionSpace,
Matrix,
Mesh,
SpatialCoordinate,
TestFunction,
Expand All @@ -19,6 +23,7 @@
grad,
inner,
nabla_grad,
replace,
triangle,
)
from ufl.form import BaseForm
Expand Down Expand Up @@ -203,3 +208,151 @@ def test_formsum(mass):
assert f.weights()[0] == -1
assert isinstance(df, FormSum)
assert df.weights()[0] == -9


def test_form_product_constructor_and_arguments(domain):
element = LagrangeElement(triangle, 1)
V = FunctionSpace(domain, element)
v = TestFunction(V)
f = Coefficient(V)
g = Coefficient(V)
h = Coefficient(V)

Lf = f * v * dx
Lg = g * v * dx
Lh = h * v * dx

product = FormProduct(Lf, Lg)
assert isinstance(product, BaseForm)
assert product.factors() == (Lf, Lg)
assert product.ufl_operands == (Lf, Lg)
assert product.factor_arguments() == (Lf.arguments(), Lg.arguments())

arguments = product.arguments()
assert tuple(argument.number() for argument in arguments) == (0, 1)
assert tuple(argument.part() for argument in arguments) == (None, None)
assert tuple(argument.ufl_function_space() for argument in arguments) == (V, V)
assert Lg.arguments()[0].number() == 0

assert product.coefficients() == (f, g)
assert product.ufl_domains() == (domain,)

nested = FormProduct(Lf, FormProduct(Lg, Lh))
assert nested.factors() == (Lf, Lg, Lh)
assert tuple(argument.number() for argument in nested.arguments()) == (0, 1, 2)


def test_form_product_of_one_factor_simplifies(domain):
element = LagrangeElement(triangle, 1)
V = FunctionSpace(domain, element)
v = TestFunction(V)
f = Coefficient(V)
L = f * v * dx

assert FormProduct(L) is L


def test_form_product_rejects_invalid_inputs(domain):
element = LagrangeElement(triangle, 1)
V = FunctionSpace(domain, element)
v = TestFunction(V)
f = Coefficient(V)
L = f * v * dx

with pytest.raises(ValueError):
FormProduct()
with pytest.raises(TypeError):
FormProduct(1)
with pytest.raises(TypeError):
FormProduct(L, 1)


def test_form_product_is_explicit_not_mul_overload(domain):
element = LagrangeElement(triangle, 1)
V = FunctionSpace(domain, element)
v = TestFunction(V)
f = Coefficient(V)
g = Coefficient(V)
Lf = f * v * dx
Lg = g * v * dx

with pytest.raises(TypeError):
Lf * Lg


def test_adjoint_form_product_reverses_adjoint_factors(domain):
element = LagrangeElement(triangle, 1)
V = FunctionSpace(domain, element)
A = Matrix(V, V)
B = Matrix(V, V)
C = Matrix(V, V)

product = FormProduct(A, B, C)
adjoint_product = Adjoint(product)

assert isinstance(adjoint_product, FormProduct)
assert tuple(factor.form() for factor in adjoint_product.factors()) == (C, B, A)
assert adjoint_product.factors() == (Adjoint(C), Adjoint(B), Adjoint(A))


def test_adjoint_form_product_leaves_rank_zero_and_one_factors_unadjointed(domain):
element = LagrangeElement(triangle, 1)
V = FunctionSpace(domain, element)
v = TestFunction(V)
f = Coefficient(V)
functional = f * dx
linear = f * v * dx
A = Matrix(V, V)

product = FormProduct(functional, linear, A)
adjoint_product = Adjoint(product)

assert isinstance(adjoint_product, FormProduct)
assert adjoint_product.factors() == (Adjoint(A), linear, functional)
assert Adjoint(FormProduct(functional, linear)).factors() == (linear, functional)


def test_form_product_replace(domain):
element = LagrangeElement(triangle, 1)
V = FunctionSpace(domain, element)
v = TestFunction(V)
f = Coefficient(V)
g = Coefficient(V)

Lf = f * v * dx
Lg = g * v * dx
product = FormProduct(Lf, Lg)
replaced = replace(product, {f: g})

assert isinstance(replaced, FormProduct)
assert bool(replaced.factors()[0] == Lg)
assert bool(replaced.factors()[1] == Lg)
assert tuple(argument.number() for argument in replaced.arguments()) == (0, 1)


def test_form_product_derivative_product_rule(domain):
element = LagrangeElement(triangle, 1)
V = FunctionSpace(domain, element)
v = TestFunction(V)
f = Coefficient(V)
direction = Argument(V, 1)

L = f * v * dx
product = FormProduct(L, L)
dproduct = derivative(product, f, direction)

assert isinstance(dproduct, FormSum)
assert len(dproduct.components()) == 2
assert all(isinstance(component, FormProduct) for component in dproduct.components())

dL = derivative(L, f, direction)
expected_first = FormProduct(dL, L)
expected_second = FormProduct(L, dL)
assert bool(dproduct.components()[0] == expected_first)
assert bool(dproduct.components()[1] == expected_second)


def test_form_product_exported_from_classes():
from ufl.classes import FormProduct as ClassesFormProduct

assert ClassesFormProduct is FormProduct
3 changes: 2 additions & 1 deletion ufl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@
from ufl.core.multiindex import Index, indices
from ufl.domain import AbstractDomain, Mesh, MeshSequence, MeshView
from ufl.finiteelement import AbstractFiniteElement
from ufl.form import BaseForm, Form, FormSum, ZeroBaseForm
from ufl.form import BaseForm, Form, FormProduct, FormSum, ZeroBaseForm
from ufl.formoperators import (
action,
adjoint,
Expand Down Expand Up @@ -491,6 +491,7 @@
"FacetArea",
"FacetNormal",
"Form",
"FormProduct",
"FormSum",
"FunctionSpace",
"H1Curl",
Expand Down
10 changes: 9 additions & 1 deletion ufl/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ufl.argument import Coargument
from ufl.core.ufl_type import ufl_type
from ufl.form import BaseForm, FormSum, ZeroBaseForm
from ufl.form import BaseForm, FormProduct, FormSum, ZeroBaseForm

# --- The Adjoint class represents the adjoint of a numerical object that
# needs to be computed at assembly time ---
Expand Down Expand Up @@ -50,6 +50,14 @@ def __new__(cls, *args, **kw):
elif isinstance(form, FormSum):
# Adjoint distributes over sums
return FormSum(*((Adjoint(c), w) for c, w in zip(form.components(), form.weights())))
elif isinstance(form, FormProduct):
# Reverse product order and take the adjoint of rank-2 factors.
return FormProduct(
*(
factor if len(factor.arguments()) < 2 else Adjoint(factor)
for factor in reversed(form.factors())
)
)
elif isinstance(form, Coargument):
# The adjoint of a coargument `c: V* -> V*` is the identity
# matrix mapping from V to V (i.e. V x V* -> R).
Expand Down
9 changes: 8 additions & 1 deletion ufl/algorithms/map_integrands.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ufl.constantvalue import Zero
from ufl.core.expr import Expr
from ufl.corealg.map_dag import map_expr_dag
from ufl.form import BaseForm, Form, FormSum, ZeroBaseForm
from ufl.form import BaseForm, Form, FormProduct, FormSum, ZeroBaseForm
from ufl.integral import Integral


Expand Down Expand Up @@ -69,6 +69,13 @@ def map_integrands(function, form, only_integral_type=None):
right = map_integrands(function, form._right, only_integral_type)
# Zeros are caught inside `Action.__new__`
return Action(left, right)
elif isinstance(form, FormProduct):
factors = tuple(
map_integrands(function, factor, only_integral_type) for factor in form.factors()
)
if any(factor == 0 for factor in factors):
return ZeroBaseForm(form.arguments())
return FormProduct(*factors)
elif isinstance(form, ZeroBaseForm):
arguments = tuple(
map_integrands(function, arg, only_integral_type) for arg in form._arguments
Expand Down
5 changes: 3 additions & 2 deletions ufl/algorithms/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ufl.action import Action
from ufl.adjoint import Adjoint
from ufl.core.expr import Expr
from ufl.form import BaseForm, Form, FormSum
from ufl.form import BaseForm, Form, FormProduct, FormSum
from ufl.integral import Integral


Expand All @@ -23,14 +23,15 @@ def iter_expressions(a):
- a is an Integral: the integrand expression of a
- a is a Form: all integrand expressions of all integrals
- a is a FormSum: the components of a
- a is a FormProduct: the factors of a
- a is an Action: the left and right component of a
- a is an Adjoint: the underlying form of a
"""
if isinstance(a, Form):
return (itg.integrand() for itg in a.integrals())
elif isinstance(a, Integral):
return (a.integrand(),)
elif isinstance(a, FormSum | Adjoint | Action):
elif isinstance(a, FormSum | FormProduct | Adjoint | Action):
return tuple(e for op in a.ufl_operands for e in iter_expressions(op))
elif isinstance(a, Expr | BaseForm):
return (a,)
Expand Down
3 changes: 2 additions & 1 deletion ufl/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
from ufl.equation import Equation
from ufl.exprcontainers import ExprList, ExprMapping
from ufl.finiteelement import AbstractFiniteElement
from ufl.form import BaseForm, Form, FormSum, ZeroBaseForm
from ufl.form import BaseForm, Form, FormProduct, FormSum, ZeroBaseForm
from ufl.functionspace import (
AbstractFunctionSpace,
DualSpace,
Expand Down Expand Up @@ -323,6 +323,7 @@
"Form",
"Form",
"FormArgument",
"FormProduct",
"FormSum",
"FunctionSpace",
"GeometricCellQuantity",
Expand Down
Loading