Skip to content

Commit f473388

Browse files
pbrubeckksagiyam
andauthored
FEniCS-style bcs (#3995)
* FEniCS-style bcs * Delay Dirichlet Lifting for LVP, MG, and Fieldsplit * LinearSolver: support pre_apply_bcs * Update firedrake/assemble.py Co-authored-by: ksagiyam <46749170+ksagiyam@users.noreply.github.com> --------- Co-authored-by: ksagiyam <46749170+ksagiyam@users.noreply.github.com>
1 parent 3331ed1 commit f473388

File tree

11 files changed

+277
-146
lines changed

11 files changed

+277
-146
lines changed

firedrake/adjoint_utils/blocks/solving.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,10 @@ def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs):
197197

198198
def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
199199
dJdu_copy = dJdu.copy()
200-
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
200+
# Homogenize and apply boundary conditions on adj_dFdu.
201201
bcs = self._homogenize_bcs()
202202
dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs)
203203

204-
for bc in bcs:
205-
bc.zero(dJdu)
206-
207204
adj_sol = firedrake.Function(self.function_space)
208205
firedrake.solve(
209206
dFdu, adj_sol, dJdu, *self.adj_args, **self.adj_kwargs
@@ -526,18 +523,11 @@ def _forward_solve(self, lhs, rhs, func, bcs):
526523
return func
527524

528525
def _assembled_solve(self, lhs, rhs, func, bcs, **kwargs):
529-
rhs_func = rhs.riesz_representation(riesz_map="l2")
530-
for bc in bcs:
531-
bc.apply(rhs_func)
532-
rhs.assign(rhs_func.riesz_representation(riesz_map="l2"))
533526
firedrake.solve(lhs, func, rhs, **kwargs)
534527
return func
535528

536529
def recompute_component(self, inputs, block_variable, idx, prepared):
537-
lhs = prepared[0]
538-
rhs = prepared[1]
539-
func = prepared[2]
540-
bcs = prepared[3]
530+
lhs, rhs, func, bcs = prepared
541531
result = self._forward_solve(lhs, rhs, func, bcs)
542532
if isinstance(block_variable.checkpoint, firedrake.Function):
543533
result = block_variable.checkpoint.assign(result)

firedrake/assemble.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ def assemble(expr, *args, **kwargs):
9595
`matrix.Matrix`.
9696
is_base_form_preprocessed : bool
9797
If `True`, skip preprocessing of the form.
98+
current_state : firedrake.function.Function or None
99+
If provided and ``zero_bc_nodes == False``, the boundary condition
100+
nodes of the output are set to the residual of the boundary conditions
101+
computed as ``current_state`` minus the boundary condition value.
98102
99103
Returns
100104
-------
@@ -130,16 +134,21 @@ def assemble(expr, *args, **kwargs):
130134
"""
131135
if args:
132136
raise RuntimeError(f"Got unexpected args: {args}")
133-
tensor = kwargs.pop("tensor", None)
134-
return get_assembler(expr, *args, **kwargs).assemble(tensor=tensor)
137+
138+
assemble_kwargs = {}
139+
for key in ("tensor", "current_state"):
140+
if key in kwargs:
141+
assemble_kwargs[key] = kwargs.pop(key, None)
142+
return get_assembler(expr, *args, **kwargs).assemble(**assemble_kwargs)
135143

136144

137145
def get_assembler(form, *args, **kwargs):
138146
"""Create an assembler.
139147
140148
Notes
141149
-----
142-
See `assemble` for descriptions of the parameters. ``tensor`` should not be passed to this function.
150+
See `assemble` for descriptions of the parameters. ``tensor`` and
151+
``current_state`` should not be passed to this function.
143152
144153
"""
145154
is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False)
@@ -187,13 +196,15 @@ class ExprAssembler(object):
187196
def __init__(self, expr):
188197
self._expr = expr
189198

190-
def assemble(self, tensor=None):
199+
def assemble(self, tensor=None, current_state=None):
191200
"""Assemble the pointwise expression.
192201
193202
Parameters
194203
----------
195204
tensor : firedrake.function.Function or firedrake.cofunction.Cofunction or matrix.MatrixBase
196205
Output tensor.
206+
current_state : None
207+
Ignored by this class.
197208
198209
Returns
199210
-------
@@ -205,6 +216,7 @@ def assemble(self, tensor=None):
205216
from ufl.checks import is_scalar_constant_expression
206217

207218
assert tensor is None
219+
assert current_state is None
208220
expr = self._expr
209221
# Get BaseFormOperators (e.g. `Interpolate` or `ExternalOperator`)
210222
base_form_operators = extract_base_form_operators(expr)
@@ -274,13 +286,16 @@ def allocate(self):
274286
"""Allocate memory for the output tensor."""
275287

276288
@abc.abstractmethod
277-
def assemble(self, tensor=None):
289+
def assemble(self, tensor=None, current_state=None):
278290
"""Assemble the form.
279291
280292
Parameters
281293
----------
282294
tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
283295
Output tensor to contain the result of assembly; if `None`, a tensor of appropriate type is created.
296+
current_state : firedrake.function.Function or None
297+
If provided, the boundary condition nodes are set to the boundary condition residual
298+
computed as ``current_state`` minus the boundary condition value.
284299
285300
Returns
286301
-------
@@ -358,13 +373,16 @@ def allocation_integral_types(self):
358373
else:
359374
return self._allocation_integral_types
360375

361-
def assemble(self, tensor=None):
376+
def assemble(self, tensor=None, current_state=None):
362377
"""Assemble the form.
363378
364379
Parameters
365380
----------
366381
tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
367382
Output tensor to contain the result of assembly.
383+
current_state : firedrake.function.Function or None
384+
If provided, the boundary condition nodes are set to the boundary condition residual
385+
computed as ``current_state`` minus the boundary condition value.
368386
369387
Returns
370388
-------
@@ -389,7 +407,7 @@ def visitor(e, *operands):
389407
rank = len(self._form.arguments())
390408
if rank == 1 and not isinstance(result, ufl.ZeroBaseForm):
391409
for bc in self._bcs:
392-
bc.zero(result)
410+
OneFormAssembler._apply_bc(self, result, bc, u=current_state)
393411

394412
if tensor:
395413
BaseFormAssembler.update_tensor(result, tensor)
@@ -968,13 +986,16 @@ def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=
968986
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters)
969987
self._needs_zeroing = needs_zeroing
970988

971-
def assemble(self, tensor=None):
989+
def assemble(self, tensor=None, current_state=None):
972990
"""Assemble the form.
973991
974992
Parameters
975993
----------
976994
tensor : firedrake.cofunction.Cofunction or matrix.MatrixBase
977995
Output tensor to contain the result of assembly; if `None`, a tensor of appropriate type is created.
996+
current_state : firedrake.function.Function or None
997+
If provided, the boundary condition nodes are set to the boundary condition residual
998+
computed as ``current_state`` minus the boundary condition value.
978999
9791000
Returns
9801001
-------
@@ -998,12 +1019,12 @@ def assemble(self, tensor=None):
9981019
self.execute_parloops(tensor)
9991020

10001021
for bc in self._bcs:
1001-
self._apply_bc(tensor, bc)
1022+
self._apply_bc(tensor, bc, u=current_state)
10021023

10031024
return self.result(tensor)
10041025

10051026
@abc.abstractmethod
1006-
def _apply_bc(self, tensor, bc):
1027+
def _apply_bc(self, tensor, bc, u=None):
10071028
"""Apply boundary condition."""
10081029

10091030
@abc.abstractmethod
@@ -1138,7 +1159,7 @@ def allocate(self):
11381159
comm=self._form.ufl_domains()[0]._comm
11391160
)
11401161

1141-
def _apply_bc(self, tensor, bc):
1162+
def _apply_bc(self, tensor, bc, u=None):
11421163
pass
11431164

11441165
def _check_tensor(self, tensor):
@@ -1199,26 +1220,29 @@ def allocate(self):
11991220
else:
12001221
raise RuntimeError(f"Not expected: found rank = {rank} and diagonal = {self._diagonal}")
12011222

1202-
def _apply_bc(self, tensor, bc):
1223+
def _apply_bc(self, tensor, bc, u=None):
12031224
# TODO Maybe this could be a singledispatchmethod?
12041225
if isinstance(bc, DirichletBC):
1205-
self._apply_dirichlet_bc(tensor, bc)
1226+
if self._diagonal:
1227+
bc.set(tensor, self._weight)
1228+
elif self._zero_bc_nodes:
1229+
bc.zero(tensor)
1230+
else:
1231+
# The residual belongs to a mixed space that is dual on the boundary nodes
1232+
# and primal on the interior nodes. Therefore, this is a type-safe operation.
1233+
r = tensor.riesz_representation("l2")
1234+
bc.apply(r, u=u)
12061235
elif isinstance(bc, EquationBCSplit):
12071236
bc.zero(tensor)
1208-
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
1209-
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor)
1237+
OneFormAssembler(bc.f, bcs=bc.bcs,
1238+
form_compiler_parameters=self._form_compiler_params,
1239+
needs_zeroing=False,
1240+
zero_bc_nodes=self._zero_bc_nodes,
1241+
diagonal=self._diagonal,
1242+
weight=self._weight).assemble(tensor=tensor, current_state=u)
12101243
else:
12111244
raise AssertionError
12121245

1213-
def _apply_dirichlet_bc(self, tensor, bc):
1214-
if self._diagonal:
1215-
bc.set(tensor, self._weight)
1216-
elif not self._zero_bc_nodes:
1217-
# NOTE this only works if tensor is a Function and not a Cofunction
1218-
bc.apply(tensor)
1219-
else:
1220-
bc.zero(tensor)
1221-
12221246
def _check_tensor(self, tensor):
12231247
if tensor.function_space() != self._form.arguments()[0].function_space().dual():
12241248
raise ValueError("Form's argument does not match provided result tensor")
@@ -1430,7 +1454,8 @@ def _all_assemblers(self):
14301454
all_assemblers.extend(_assembler._all_assemblers)
14311455
return tuple(all_assemblers)
14321456

1433-
def _apply_bc(self, tensor, bc):
1457+
def _apply_bc(self, tensor, bc, u=None):
1458+
assert u is None
14341459
op2tensor = tensor.M
14351460
spaces = tuple(a.function_space() for a in tensor.a.arguments())
14361461
V = bc.function_space()
@@ -1534,7 +1559,7 @@ def allocate(self):
15341559
options_prefix=self._options_prefix,
15351560
appctx=self._appctx or {})
15361561

1537-
def assemble(self, tensor=None):
1562+
def assemble(self, tensor=None, current_state=None):
15381563
if tensor is None:
15391564
tensor = self.allocate()
15401565
else:

firedrake/bcs.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def extract_form(self, form_type):
462462
# DirichletBC is directly used in assembly.
463463
return self
464464

465-
def _as_nonlinear_variational_problem_arg(self):
465+
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
466466
return self
467467

468468

@@ -501,15 +501,16 @@ def __init__(self, *args, bcs=None, J=None, Jp=None, V=None, is_linear=False, Jp
501501
# linear
502502
if isinstance(eq.lhs, ufl.Form) and isinstance(eq.rhs, ufl.Form):
503503
J = eq.lhs
504+
L = eq.rhs
504505
Jp = Jp or J
505-
if eq.rhs == 0:
506+
if L == 0 or L.empty():
506507
F = ufl_expr.action(J, u)
507508
else:
508-
if not isinstance(eq.rhs, (ufl.Form, slate.slate.TensorBase)):
509-
raise TypeError("Provided BC RHS is a '%s', not a Form or Slate Tensor" % type(eq.rhs).__name__)
510-
if len(eq.rhs.arguments()) != 1:
509+
if not isinstance(L, (ufl.BaseForm, slate.slate.TensorBase)):
510+
raise TypeError("Provided BC RHS is a '%s', not a BaseForm or Slate Tensor" % type(L).__name__)
511+
if len(L.arguments()) != 1:
511512
raise ValueError("Provided BC RHS is not a linear form")
512-
F = ufl_expr.action(J, u) - eq.rhs
513+
F = ufl_expr.action(J, u) - L
513514
self.is_linear = True
514515
# nonlinear
515516
else:
@@ -531,9 +532,7 @@ def __init__(self, *args, bcs=None, J=None, Jp=None, V=None, is_linear=False, Jp
531532
# reconstruction for splitting `solving_utils.split`
532533
self.Jp_eq_J = Jp_eq_J
533534
self.is_linear = is_linear
534-
self._F = args[0]
535-
self._J = args[1]
536-
self._Jp = args[2]
535+
self._F, self._J, self._Jp = args[:3]
537536
else:
538537
raise TypeError("Wrong EquationBC arguments")
539538

@@ -562,7 +561,7 @@ def reconstruct(self, V, subu, u, field, is_linear):
562561
if all([_F is not None, _J is not None, _Jp is not None]):
563562
return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=is_linear)
564563

565-
def _as_nonlinear_variational_problem_arg(self):
564+
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
566565
return self
567566

568567

@@ -654,19 +653,20 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
654653
ebc.add(bc_temp)
655654
return ebc
656655

657-
def _as_nonlinear_variational_problem_arg(self):
656+
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
658657
# NonlinearVariationalProblem expects EquationBC, not EquationBCSplit.
659658
# -- This method is required when NonlinearVariationalProblem is constructed inside PC.
660659
if len(self.f.arguments()) != 2:
661660
raise NotImplementedError(f"Not expecting a form of rank {len(self.f.arguments())} (!= 2)")
662661
J = self.f
663662
Vcol = J.arguments()[-1].function_space()
664663
u = firedrake.Function(Vcol)
665-
F = ufl_expr.action(J, u)
666664
Vrow = self._function_space
667665
sub_domain = self.sub_domain
668-
bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in self.bcs)
669-
return EquationBC(F == 0, u, sub_domain, bcs=bcs, J=J, V=Vrow)
666+
bcs = tuple(bc._as_nonlinear_variational_problem_arg(is_linear=is_linear) for bc in self.bcs)
667+
lhs = J if is_linear else ufl_expr.action(J, u)
668+
rhs = ufl.Form([]) if is_linear else 0
669+
return EquationBC(lhs == rhs, u, sub_domain, bcs=bcs, J=J, V=Vrow)
670670

671671

672672
@PETSc.Log.EventDecorator()

0 commit comments

Comments
 (0)