@@ -95,6 +95,10 @@ def assemble(expr, *args, **kwargs):
95
95
`matrix.Matrix`.
96
96
is_base_form_preprocessed : bool
97
97
If `True`, skip preprocessing of the form.
98
+ current_state : firedrake.function.Function or None
99
+ If provided and ``zero_bc_nodes == False``, the boundary condition
100
+ nodes of the output are set to the residual of the boundary conditions
101
+ computed as ``current_state`` minus the boundary condition value.
98
102
99
103
Returns
100
104
-------
@@ -130,16 +134,21 @@ def assemble(expr, *args, **kwargs):
130
134
"""
131
135
if args :
132
136
raise RuntimeError (f"Got unexpected args: { args } " )
133
- tensor = kwargs .pop ("tensor" , None )
134
- return get_assembler (expr , * args , ** kwargs ).assemble (tensor = tensor )
137
+
138
+ assemble_kwargs = {}
139
+ for key in ("tensor" , "current_state" ):
140
+ if key in kwargs :
141
+ assemble_kwargs [key ] = kwargs .pop (key , None )
142
+ return get_assembler (expr , * args , ** kwargs ).assemble (** assemble_kwargs )
135
143
136
144
137
145
def get_assembler (form , * args , ** kwargs ):
138
146
"""Create an assembler.
139
147
140
148
Notes
141
149
-----
142
- See `assemble` for descriptions of the parameters. ``tensor`` should not be passed to this function.
150
+ See `assemble` for descriptions of the parameters. ``tensor`` and
151
+ ``current_state`` should not be passed to this function.
143
152
144
153
"""
145
154
is_base_form_preprocessed = kwargs .pop ('is_base_form_preprocessed' , False )
@@ -187,13 +196,15 @@ class ExprAssembler(object):
187
196
def __init__ (self , expr ):
188
197
self ._expr = expr
189
198
190
- def assemble (self , tensor = None ):
199
+ def assemble (self , tensor = None , current_state = None ):
191
200
"""Assemble the pointwise expression.
192
201
193
202
Parameters
194
203
----------
195
204
tensor : firedrake.function.Function or firedrake.cofunction.Cofunction or matrix.MatrixBase
196
205
Output tensor.
206
+ current_state : None
207
+ Ignored by this class.
197
208
198
209
Returns
199
210
-------
@@ -205,6 +216,7 @@ def assemble(self, tensor=None):
205
216
from ufl .checks import is_scalar_constant_expression
206
217
207
218
assert tensor is None
219
+ assert current_state is None
208
220
expr = self ._expr
209
221
# Get BaseFormOperators (e.g. `Interpolate` or `ExternalOperator`)
210
222
base_form_operators = extract_base_form_operators (expr )
@@ -274,13 +286,16 @@ def allocate(self):
274
286
"""Allocate memory for the output tensor."""
275
287
276
288
@abc .abstractmethod
277
- def assemble (self , tensor = None ):
289
+ def assemble (self , tensor = None , current_state = None ):
278
290
"""Assemble the form.
279
291
280
292
Parameters
281
293
----------
282
294
tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
283
295
Output tensor to contain the result of assembly; if `None`, a tensor of appropriate type is created.
296
+ current_state : firedrake.function.Function or None
297
+ If provided, the boundary condition nodes are set to the boundary condition residual
298
+ computed as ``current_state`` minus the boundary condition value.
284
299
285
300
Returns
286
301
-------
@@ -358,13 +373,16 @@ def allocation_integral_types(self):
358
373
else :
359
374
return self ._allocation_integral_types
360
375
361
- def assemble (self , tensor = None ):
376
+ def assemble (self , tensor = None , current_state = None ):
362
377
"""Assemble the form.
363
378
364
379
Parameters
365
380
----------
366
381
tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
367
382
Output tensor to contain the result of assembly.
383
+ current_state : firedrake.function.Function or None
384
+ If provided, the boundary condition nodes are set to the boundary condition residual
385
+ computed as ``current_state`` minus the boundary condition value.
368
386
369
387
Returns
370
388
-------
@@ -389,7 +407,7 @@ def visitor(e, *operands):
389
407
rank = len (self ._form .arguments ())
390
408
if rank == 1 and not isinstance (result , ufl .ZeroBaseForm ):
391
409
for bc in self ._bcs :
392
- bc . zero ( result )
410
+ OneFormAssembler . _apply_bc ( self , result , bc , u = current_state )
393
411
394
412
if tensor :
395
413
BaseFormAssembler .update_tensor (result , tensor )
@@ -968,13 +986,16 @@ def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=
968
986
super ().__init__ (form , bcs = bcs , form_compiler_parameters = form_compiler_parameters )
969
987
self ._needs_zeroing = needs_zeroing
970
988
971
- def assemble (self , tensor = None ):
989
+ def assemble (self , tensor = None , current_state = None ):
972
990
"""Assemble the form.
973
991
974
992
Parameters
975
993
----------
976
994
tensor : firedrake.cofunction.Cofunction or matrix.MatrixBase
977
995
Output tensor to contain the result of assembly; if `None`, a tensor of appropriate type is created.
996
+ current_state : firedrake.function.Function or None
997
+ If provided, the boundary condition nodes are set to the boundary condition residual
998
+ computed as ``current_state`` minus the boundary condition value.
978
999
979
1000
Returns
980
1001
-------
@@ -998,12 +1019,12 @@ def assemble(self, tensor=None):
998
1019
self .execute_parloops (tensor )
999
1020
1000
1021
for bc in self ._bcs :
1001
- self ._apply_bc (tensor , bc )
1022
+ self ._apply_bc (tensor , bc , u = current_state )
1002
1023
1003
1024
return self .result (tensor )
1004
1025
1005
1026
@abc .abstractmethod
1006
- def _apply_bc (self , tensor , bc ):
1027
+ def _apply_bc (self , tensor , bc , u = None ):
1007
1028
"""Apply boundary condition."""
1008
1029
1009
1030
@abc .abstractmethod
@@ -1138,7 +1159,7 @@ def allocate(self):
1138
1159
comm = self ._form .ufl_domains ()[0 ]._comm
1139
1160
)
1140
1161
1141
- def _apply_bc (self , tensor , bc ):
1162
+ def _apply_bc (self , tensor , bc , u = None ):
1142
1163
pass
1143
1164
1144
1165
def _check_tensor (self , tensor ):
@@ -1199,26 +1220,29 @@ def allocate(self):
1199
1220
else :
1200
1221
raise RuntimeError (f"Not expected: found rank = { rank } and diagonal = { self ._diagonal } " )
1201
1222
1202
- def _apply_bc (self , tensor , bc ):
1223
+ def _apply_bc (self , tensor , bc , u = None ):
1203
1224
# TODO Maybe this could be a singledispatchmethod?
1204
1225
if isinstance (bc , DirichletBC ):
1205
- self ._apply_dirichlet_bc (tensor , bc )
1226
+ if self ._diagonal :
1227
+ bc .set (tensor , self ._weight )
1228
+ elif self ._zero_bc_nodes :
1229
+ bc .zero (tensor )
1230
+ else :
1231
+ # The residual belongs to a mixed space that is dual on the boundary nodes
1232
+ # and primal on the interior nodes. Therefore, this is a type-safe operation.
1233
+ r = tensor .riesz_representation ("l2" )
1234
+ bc .apply (r , u = u )
1206
1235
elif isinstance (bc , EquationBCSplit ):
1207
1236
bc .zero (tensor )
1208
- type (self )(bc .f , bcs = bc .bcs , form_compiler_parameters = self ._form_compiler_params , needs_zeroing = False ,
1209
- zero_bc_nodes = self ._zero_bc_nodes , diagonal = self ._diagonal , weight = self ._weight ).assemble (tensor = tensor )
1237
+ OneFormAssembler (bc .f , bcs = bc .bcs ,
1238
+ form_compiler_parameters = self ._form_compiler_params ,
1239
+ needs_zeroing = False ,
1240
+ zero_bc_nodes = self ._zero_bc_nodes ,
1241
+ diagonal = self ._diagonal ,
1242
+ weight = self ._weight ).assemble (tensor = tensor , current_state = u )
1210
1243
else :
1211
1244
raise AssertionError
1212
1245
1213
- def _apply_dirichlet_bc (self , tensor , bc ):
1214
- if self ._diagonal :
1215
- bc .set (tensor , self ._weight )
1216
- elif not self ._zero_bc_nodes :
1217
- # NOTE this only works if tensor is a Function and not a Cofunction
1218
- bc .apply (tensor )
1219
- else :
1220
- bc .zero (tensor )
1221
-
1222
1246
def _check_tensor (self , tensor ):
1223
1247
if tensor .function_space () != self ._form .arguments ()[0 ].function_space ().dual ():
1224
1248
raise ValueError ("Form's argument does not match provided result tensor" )
@@ -1430,7 +1454,8 @@ def _all_assemblers(self):
1430
1454
all_assemblers .extend (_assembler ._all_assemblers )
1431
1455
return tuple (all_assemblers )
1432
1456
1433
- def _apply_bc (self , tensor , bc ):
1457
+ def _apply_bc (self , tensor , bc , u = None ):
1458
+ assert u is None
1434
1459
op2tensor = tensor .M
1435
1460
spaces = tuple (a .function_space () for a in tensor .a .arguments ())
1436
1461
V = bc .function_space ()
@@ -1534,7 +1559,7 @@ def allocate(self):
1534
1559
options_prefix = self ._options_prefix ,
1535
1560
appctx = self ._appctx or {})
1536
1561
1537
- def assemble (self , tensor = None ):
1562
+ def assemble (self , tensor = None , current_state = None ):
1538
1563
if tensor is None :
1539
1564
tensor = self .allocate ()
1540
1565
else :
0 commit comments