Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dualspace update #2294

Merged
merged 180 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 174 commits
Commits
Show all changes
180 commits
Select commit Hold shift + click to select a range
c8f8ff7
Use dualspace ufl branch in ci
indiamai Apr 21, 2021
0b56ea4
Use dualspace ufl branch in ci
indiamai Apr 21, 2021
576d22f
Merge branch 'dualspace' of https://github.com/firedrakeproject/fired…
indiamai May 5, 2021
6a1ce4e
loose first draft of assemble changes
indiamai May 11, 2021
afa7f6a
further attempts at firedrake cofunction
indiamai May 17, 2021
352a3fe
Assembling Formsum + tests
indiamai May 19, 2021
056946b
Add assembled matrix
indiamai May 21, 2021
50724d3
update assertion in linear solve
indiamai May 21, 2021
bed449e
Basic tests passing, add dual with geometry
indiamai May 24, 2021
febc6bc
Clean up + add matrix-matrix action
indiamai May 26, 2021
5f3b5be
Isolate broken test outside pytest temporarily
indiamai May 26, 2021
88253ce
Fix assemble issue and comment
indiamai Jun 2, 2021
fec4a8a
Lint (finally)
indiamai Jun 2, 2021
c3e1f63
tidying up
indiamai Jun 10, 2021
f0511e9
fixing merge conflicts
colinjcotter Jul 6, 2021
3679674
Use dualspace ufl branch in actions workflow
colinjcotter Jul 6, 2021
0f073f2
fixing lint
colinjcotter Jul 6, 2021
87188ad
Merge master
nbouziani Nov 29, 2021
a04c3c5
Add firedrake.Coargument
nbouziani Nov 30, 2021
b769f44
Remove ABCMeta from MatrixBase
nbouziani Nov 30, 2021
f9229d8
Add UFL branch
nbouziani Nov 30, 2021
1c7c56b
Cleanup
nbouziani Nov 30, 2021
8f71503
Workaround for UFLType
nbouziani Dec 1, 2021
8bdc83f
Fix test_assemble_formbase.py
nbouziani Dec 1, 2021
4bd8412
Fix split for Cofunction
nbouziani Dec 1, 2021
61eb6ef
Fix pyadjoint (chapter 1)
nbouziani Dec 6, 2021
a84b3f6
Update Cofunction assign + add test
nbouziani Dec 7, 2021
82d8557
Fix test_quadrature
nbouziani Dec 7, 2021
9fa651e
Fix syntax for Cofunction
nbouziani Dec 7, 2021
57e7795
Add ConstantValue case in Cofunction.assign
nbouziani Dec 7, 2021
f4f9c61
Fix sub function in solving_utils
nbouziani Dec 9, 2021
53de7c1
Fix few things with pyadjoint
nbouziani Dec 14, 2021
930082e
Fix pyadjoint
nbouziani Dec 14, 2021
86c7943
Fix firedrake tests (chapter 1)
nbouziani Dec 15, 2021
0e70cc9
Update docs.yml
nbouziani Dec 15, 2021
9186f4c
Add node_set to Cofunction
nbouziani Dec 15, 2021
e38f2bd
Merge branch 'dualspace_update' of github.com:firedrakeproject/firedr…
nbouziani Dec 15, 2021
58c42ec
Update docs.yml
nbouziani Dec 15, 2021
44bf5bb
Other attempt to fix docs.yml
nbouziani Dec 15, 2021
3d7f5d6
Fix firedrake tests (chapter 2)
nbouziani Dec 15, 2021
cad1509
Fix firedrake tests (chapter 3)
nbouziani Dec 15, 2021
43cfe2f
Modify AssembledVector
nbouziani Dec 15, 2021
94cc748
Slate: assembled vector data can be a cofunction now.
sv2518 Dec 15, 2021
abb1eaf
Cofunction: introduce cell_node_map function
sv2518 Dec 15, 2021
f833e78
Slate: check fo cofunction not basecoefficient in the coeff map gener…
sv2518 Dec 15, 2021
edbc52f
Expunge UFLType from Firedrake
nbouziani Dec 18, 2021
4a93f46
Fix Interpolate block
nbouziani Dec 23, 2021
cd278c0
Fix notebook 11 by computing the riesz representer of the adjoint
nbouziani Jan 16, 2022
2f2cd66
Merge remote-tracking branch 'origin/master' into dualspace_update
nbouziani Jan 16, 2022
fee731e
Update docs.yml
nbouziani Jan 16, 2022
b7651d9
Docstring whitespace
nbouziani Jan 16, 2022
ada1ce4
Merge remote-tracking branch 'origin/master' into dualspace_update
nbouziani Jan 21, 2022
3f7ed62
Resolve conflicts (merge master)
nbouziani Jul 12, 2022
0f025be
Fix lint
nbouziani Jul 12, 2022
33ea7c4
Fix UFL branch name (dualspace_update got merged)
nbouziani Jul 12, 2022
e4743b8
Clean *.bib files
nbouziani Jul 12, 2022
f3d252b
Merge master
nbouziani Jul 13, 2022
40f29d5
Update assembly doc
nbouziani Jul 15, 2022
7a7b60f
Fix bug in FunctionMergeBlock
nbouziani Jul 15, 2022
08b9ef6
Merge remote-tracking branch 'origin/master' into dualspace
nbouziani Jul 20, 2022
fdb7a25
Merge remote-tracking branch 'origin/master' into dualspace_update
nbouziani Jul 20, 2022
757826a
Merge remote-tracking branch 'origin/master' into dualspace
nbouziani Jul 21, 2022
a51f5a5
Merge dualspace branch
nbouziani Jul 21, 2022
677a717
Add form assembly dispatcher
nbouziani Jul 21, 2022
0760cf5
Fix function space check in _get_map
nbouziani Jul 22, 2022
9002332
Make ImplicitMatrixContext compatible with Cofunction
nbouziani Jul 20, 2022
632e930
Fix few things
nbouziani Jul 22, 2022
b1717d2
Fix failing tests
nbouziani Jul 22, 2022
2a312d1
Update check_pde_args
nbouziani Jul 27, 2022
773efeb
Merge remote-tracking branch 'origin/master' into dualspace_update
nbouziani Jul 27, 2022
7573942
Update assembly of ufl.Matrix's sum
nbouziani Aug 3, 2022
12ad7f9
Fix lint E275
nbouziani Aug 3, 2022
c92ddfa
Overwrite equals for firedrake.Coargument
nbouziani Aug 10, 2022
31be19a
Add support for ZeroBaseForm
nbouziani Aug 14, 2022
2c4d74f
Fix Matrix's arguments
nbouziani Aug 11, 2022
2a8773c
Fix lint
nbouziani Aug 22, 2022
9615a93
Merge remote-tracking branch 'origin/master' into dualspace_update
nbouziani Aug 22, 2022
34adae7
Merge remote-tracking branch 'origin/master' into dualspace_update
nbouziani Oct 21, 2022
22a698f
Merge remote-tracking branch 'origin/master' into dualspace_update
nbouziani Oct 25, 2022
5538081
Merge master
nbouziani Jan 11, 2023
45de28c
Merge master
nbouziani Jan 11, 2023
1807422
Merge master
nbouziani Jan 24, 2023
fdd0016
Remove spurious diff
nbouziani Jan 24, 2023
fda7940
Merge master
nbouziani Jan 24, 2023
09f6b22
Merge dual
nbouziani Jan 24, 2023
f82f748
Merge remote-tracking branch 'origin/master' into dualspace
nbouziani Jan 25, 2023
713f388
Merge remote-tracking branch 'origin/dualspace' into dualspace_update
nbouziani Jan 25, 2023
a47195b
Merge remote-tracking branch 'origin/master' into dualspace_update
nbouziani Jan 26, 2023
26984a0
Establish distinction between public self.comm and internal private s…
nbouziani Jan 26, 2023
3e1ce95
Equip Cofunctions with a Riesz representation method to get the Riesz…
nbouziani Jan 26, 2023
e61b79c
Add conj handler to CoefficientCollector
nbouziani Jan 26, 2023
0bec966
Annotation of 1-forms operates on Cofunction
nbouziani Jan 26, 2023
d3c0ec8
Fix lack of annotations for Riesz representation of Cofunctions
nbouziani Jan 26, 2023
7d3edd1
Remove conj handler in CoefficientCollector
nbouziani Jan 29, 2023
e5f5331
Remove UFL branch
nbouziani Mar 18, 2023
621478c
Merge master
nbouziani Mar 18, 2023
e34af65
Merge master
nbouziani Mar 18, 2023
e90d257
Fix decorator typo
nbouziani Mar 19, 2023
81eccd4
Fix decorator typo
nbouziani Mar 19, 2023
b5ae54e
Merge remote-tracking branch 'origin/dualspace' into dualspace_update
nbouziani Mar 19, 2023
f7604d1
Clean up
nbouziani Mar 19, 2023
8d58ede
Fix doc
nbouziani Mar 19, 2023
61f235a
Fix indent
nbouziani Mar 19, 2023
95d72d0
Merge master
nbouziani May 29, 2023
ef9b314
Merge master
nbouziani May 29, 2023
5d0a9c7
Merge dual
nbouziani May 29, 2023
17434ca
Merge master
nbouziani Jun 18, 2023
ba224b9
Fix domain extractions for parloop's arguments
nbouziani Jun 21, 2023
bffc2f8
Extract domains: check if Constant is the only edge case
nbouziani Jun 22, 2023
b634289
Finish _extract_domains
nbouziani Jun 24, 2023
4afc69f
Address comments
nbouziani Jun 24, 2023
6d510a3
Merge remote-tracking branch 'origin/master' into dualspace_update
nbouziani Jun 24, 2023
43731d7
Merge remote-tracking branch 'origin/master' into dualspace
nbouziani Jun 25, 2023
22a77ee
Wrap extract_unique_domain for Cofunctions
nbouziani Jun 25, 2023
9a4ea31
Merge remote-tracking branch 'origin/dualspace' into dualspace_update
nbouziani Jun 25, 2023
3a8b42b
Fix tests
nbouziani Jun 25, 2023
c901874
Extend action assembly to 1-form against 1-form
nbouziani Jun 25, 2023
e480e7f
Add imul for cofunction + fix pytorch type check
nbouziani Jun 26, 2023
20f115e
Fix some tests
nbouziani Jun 26, 2023
ffa228e
Fix typo
nbouziani Jun 26, 2023
dca4695
Support solve(a == L, u) with L a Cofunction
nbouziani Jun 29, 2023
9e0cdf5
Fix tensor update for base form assembly
nbouziani Jun 29, 2023
d573ffd
Fix some tests
nbouziani Jun 29, 2023
9a62d94
Add fix_dual ufl branch
nbouziani Jun 29, 2023
e561bcd
Fix riesz representation's annotation: solve with a rhs Cofunction
nbouziani Jul 4, 2023
0c543ca
Merge remote-tracking branch 'origin/master' into dualspace_update
nbouziani Jul 6, 2023
88149f6
Extend derivative check to Cofunction
nbouziani Jul 6, 2023
f4dcc54
Add zero-order FormSum assembly
nbouziani Jul 6, 2023
631ed9e
Transmit weight kwarg through base form assembly visitor
nbouziani Jul 6, 2023
680b223
Minor fixes
nbouziani Jul 7, 2023
5b6dc24
FiredrakeTorchOperator: Take the dual output space as the adjoint inp…
nbouziani Jul 13, 2023
b5b8fe9
Extend riesz representation mapping to Function for l2 case
nbouziani Jul 13, 2023
8f8de5b
Notebook 11: Update the way of getting riesz representer
nbouziani Jul 13, 2023
8d18cfd
Merge remote-tracking branch 'origin/master' into dualspace_update
nbouziani Jul 13, 2023
d5ca8e6
Merge remote-tracking branch 'origin/master' into dualspace
nbouziani Jul 14, 2023
921d4a4
Adapt FunctionMixin._ad_mul to cofunctions
nbouziani Jul 14, 2023
a8018b8
Update pointwise division in linear wave equation demo
nbouziani Jul 14, 2023
1ccde41
FiredrakeTorchOperator: Simplify constant adjoint input to sidestep c…
nbouziani Jul 14, 2023
27728ed
Fix Coargument form argument analysis
nbouziani Jul 14, 2023
9bd6c75
Coarguments have one argument in the primal space and one in the dual…
nbouziani Jul 18, 2023
a7d4932
Merge remote-tracking branch 'origin/dualspace' into dualspace_update
nbouziani Jul 18, 2023
2bb9d5b
Merge remote-tracking branch 'origin/JDBetteridge/installation_hotfix…
nbouziani Jul 18, 2023
7c0a721
Merge remote-tracking branch 'origin/master' into dualspace
nbouziani Aug 20, 2023
4cf92f3
Merge dualspace
nbouziani Aug 23, 2023
9728463
Move Pyadjoint dual space changes to Firedrake
nbouziani Aug 23, 2023
6023cee
Fix FunctionMixin import
nbouziani Aug 25, 2023
2cf9863
Update function used to extract unique domain during assembly
nbouziani Aug 26, 2023
7b877d8
Fix TLM of Assemble in spatial coordinate case
nbouziani Sep 5, 2023
a235aa7
Add comment
nbouziani Sep 5, 2023
13d3722
Extend assembly arguments type to ufl.Argument
nbouziani Sep 7, 2023
db31d7e
Merge remote-tracking branch 'origin/master' into dualspace
nbouziani Sep 7, 2023
a42f7fd
Merge remote-tracking branch 'origin/dualspace' into dualspace_update
nbouziani Sep 7, 2023
f57354e
Assemble block: fix ZeroBaseForm output for adjoint of spatialcoordin…
nbouziani Sep 7, 2023
9a0ab79
Lift some fixes from external operator branch
nbouziani Sep 10, 2023
75c2056
Fix a few things
nbouziani Sep 12, 2023
c188c16
mer
nbouziani Sep 12, 2023
436aef0
Strenghten criterion on preprocessing assemble_base_form input
nbouziani Sep 13, 2023
62ebd19
Fix docs test (hopefully)
nbouziani Sep 13, 2023
93e65a2
Fix docs (attempt 2)
nbouziani Sep 13, 2023
07f4628
Update Cofunction's docs
nbouziani Sep 14, 2023
623917b
Expunge Vector from firedrake.adjoint: adjoint/tlm/hessian values are…
nbouziani Sep 18, 2023
8f5da43
Fix WithGeometryBase docs
nbouziani Sep 18, 2023
49f4a13
Add riesz_representation to Function objects to get the corresponding…
nbouziani Sep 18, 2023
e54db96
Merge remote-tracking branch 'origin/master' into dualspace
nbouziani Sep 18, 2023
e3727fa
Merge remote-tracking branch 'origin/dualspace' into dualspace_update
nbouziani Sep 18, 2023
f537eaf
Delete tests/regression/isolated_auxilary.py
nbouziani Sep 19, 2023
7f37ca0
Merge remote-tracking branch 'origin/dualspace' into dualspace_update
nbouziani Sep 19, 2023
c5aff03
Clean up output of 11-extract-adjoint-solutions.ipynb
nbouziani Sep 19, 2023
bcd0185
Clean up
nbouziani Sep 19, 2023
4c0258c
Merge remote-tracking branch 'origin/master' into dualspace_update
nbouziani Sep 19, 2023
5e85161
Don't preprocess slate.TensorBase form
nbouziani Sep 19, 2023
fda4a7e
Clean up
nbouziani Sep 20, 2023
9cb03d7
Additional cleaning
nbouziani Sep 20, 2023
d3d3694
Update .github/workflows/build.yml
nbouziani Sep 20, 2023
f2bf456
Update firedrake/cofunction.py
nbouziani Sep 20, 2023
291cd99
Update multifunction import (recently changed in UFL upstream)
nbouziani Sep 20, 2023
6935f33
Address PR's comments
nbouziani Sep 20, 2023
0b8a0c2
Fix bcs tests
nbouziani Sep 21, 2023
4bda0d0
Add UFL branch that fixes the pyadjoint tests
nbouziani Sep 21, 2023
c550838
Update .github/workflows/build.yml
nbouziani Sep 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion demos/linear-wave-equation/linear_wave_equation.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ options at this point, we may either `lump` the mass, which reduces
the inversion to a pointwise division::

if lump_mass:
p += interpolate(assemble(dt * inner(nabla_grad(v), nabla_grad(phi))*dx) / assemble(v*dx), V)
p.dat.data[:] += assemble(dt * inner(nabla_grad(v), nabla_grad(phi))*dx).dat.data_ro / assemble(v*dx).dat.data_ro

In the mass lumped case, we must now ensure that the resulting
solution for :math:`p` satisfies the boundary conditions::
Expand Down
2 changes: 2 additions & 0 deletions docs/notebooks/11-extract-adjoint-solutions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,8 @@
" t = i*timesteps_per_export*dt\n",
" tricontourf(forward_solutions[i], axes=axs[i, 0])\n",
" adjoint_solution = dJdu if i == num_exports else solve_blocks[timesteps_per_export*i].adj_sol\n",
" # Get the Riesz representer\n",
" adjoint_solution = dJdu.riesz_representation(riesz_map=\"H1\")\n",
" tricontourf(adjoint_solution, axes=axs[i, 1])\n",
" axs[i, 0].annotate('t={:.2f}'.format(t), (0.05, 0.05), color='white');\n",
" axs[i, 1].annotate('t={:.2f}'.format(t), (0.05, 0.05), color='white');\n",
Expand Down
1 change: 1 addition & 0 deletions firedrake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from firedrake.assemble import *
from firedrake.bcs import *
from firedrake.checkpointing import *
from firedrake.cofunction import *
from firedrake.constant import *
from firedrake.exceptions import *
from firedrake.function import *
Expand Down
8 changes: 5 additions & 3 deletions firedrake/adjoint_utils/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ def wrapper(*args, **kwargs):
output = assemble(*args, **kwargs)

from firedrake.function import Function
from firedrake.cofunction import Cofunction
form = args[0]
if isinstance(output, (numbers.Complex, Function)):
if isinstance(output, (numbers.Complex, Function, Cofunction)):
# Assembling a 0-form or 1-form (e.g. Form)
if not annotate:
return output

if not isinstance(output, (float, Function)):
if not isinstance(output, (float, Function, Cofunction)):
raise NotImplementedError("Taping for complex-valued 0-forms not yet done!")
output = create_overloaded_object(output)
block = AssembleBlock(form, ad_block_tag=ad_block_tag)
Expand All @@ -34,7 +36,7 @@ def wrapper(*args, **kwargs):

block.add_output(output.block_variable)
else:
# Assembled a matrix
# Assembled a 2-form
output.form = form

return output
Expand Down
57 changes: 31 additions & 26 deletions firedrake/adjoint_utils/blocks/assembly.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ufl
import firedrake
from ufl.formatting.ufl2unicode import ufl2unicode
from pyadjoint import Block, create_overloaded_object
from pyadjoint import Block, AdjFloat, create_overloaded_object
from .backend import Backend
from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint

Expand All @@ -13,8 +13,11 @@ def __init__(self, form, ad_block_tag=None):
if self.backend.__name__ != "firedrake":
mesh = self.form.ufl_domain().ufl_cargo()
else:
mesh = self.form.ufl_domain()
self.add_dependency(mesh)
mesh = self.form.ufl_domain() if hasattr(self.form, 'ufl_domain') else None

if mesh:
self.add_dependency(mesh)

for c in self.form.coefficients():
self.add_dependency(c, no_duplicates=True)

Expand All @@ -28,7 +31,7 @@ def compute_action_adjoint(self, adj_input, arity_form, form=None,
`<(dform/dc_rep)*, adj_input>`

- If `form` has arity 0 => `dform/dc_rep` is a 1-form and
`adj_input` a foat, we can simply use the `*` operator.
`adj_input` a float, we can simply use the `*` operator.

- If `form` has arity 1 => `dform/dc_rep` is a 2-form and we can
symbolically take its adjoint and then apply the action on
Expand All @@ -38,31 +41,38 @@ def compute_action_adjoint(self, adj_input, arity_form, form=None,
if dform is None:
dc = self.backend.TestFunction(space)
dform = self.backend.derivative(form, c_rep, dc)
dform_vector = self.compat.assemble_adjoint_value(dform)
# Return a Vector scaled by the scalar `adj_input`
return dform_vector * adj_input, dform
dform_adj = self.compat.assemble_adjoint_value(dform)
if dform_adj == 0:
# `dform_adj` is a `ZeroBaseForm`
return AdjFloat(0.), dform
# Return the adjoint model of `form` scaled by the scalar `adj_input`
adj_output = dform_adj._ad_mul(adj_input)
return adj_output, dform
elif arity_form == 1:
if dform is None:
dc = self.backend.TrialFunction(space)
dform = self.backend.derivative(form, c_rep, dc)
# Get the Function
adj_input = adj_input.function
# Symbolic operators such as action/adjoint require derivatives to
# have been expanded beforehand. However, UFL doesn't support
# expanding coordinate derivatives of Coefficients in physical
# space, implying that we can't symbolically take the
# action/adjoint of the Jacobian for SpatialCoordinates. ->
# Workaround: Apply action/adjoint numerically (using PETSc).
# action/adjoint of the Jacobian for SpatialCoordinates.
# -> Workaround: Apply action/adjoint numerically (using PETSc).
if not isinstance(c_rep, self.backend.SpatialCoordinate):
# Symbolically compute: (dform/dc_rep)^* * adj_input
adj_output = self.backend.action(self.backend.adjoint(dform),
adj_input)
adj_output = self.compat.assemble_adjoint_value(adj_output)
else:
adj_output = self.backend.Cofunction(space.dual())
# Assemble `dform`: derivatives are expanded along the way
# which may lead to a ZeroBaseForm
assembled_dform = self.compat.assemble_adjoint_value(dform)
if assembled_dform == 0:
return adj_output, dform
# Get PETSc matrix
dform_mat = self.compat.assemble_adjoint_value(dform).petscmat
dform_mat = assembled_dform.petscmat
# Action of the adjoint (Hermitian transpose)
adj_output = self.backend.Function(space)
with adj_input.dat.vec_ro as v_vec:
with adj_output.dat.vec as res_vec:
dform_mat.multHermitian(v_vec, res_vec)
Expand Down Expand Up @@ -105,7 +115,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
if self.compat.isconstant(c):
mesh = self.compat.extract_mesh_from_form(self.form)
space = c._ad_function_space(mesh)
elif isinstance(c, self.backend.Function):
elif isinstance(c, (self.backend.Function, self.backend.Cofunction)):
space = c.function_space()
elif isinstance(c, self.compat.MeshType):
c_rep = self.backend.SpatialCoordinate(c_rep)
Expand All @@ -123,8 +133,6 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
form = prepared
dform = 0.

from ufl.algorithms.analysis import extract_arguments
arity_form = len(extract_arguments(form))
for bv in self.get_dependencies():
c_rep = bv.saved_output
tlm_value = bv.tlm_value
Expand All @@ -133,15 +141,14 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
continue
if isinstance(c_rep, self.compat.MeshType):
X = self.backend.SpatialCoordinate(c_rep)
# Spatial coordinates derivatives cannot be expanded in the physical space,
# which is required by symbolic operators such as `action`.
dform += self.backend.derivative(form, X, tlm_value)
else:
dform += self.backend.derivative(form, c_rep, tlm_value)
dform += self.backend.action(self.backend.derivative(form, c_rep), tlm_value)
if not isinstance(dform, float):
dform = ufl.algorithms.expand_derivatives(dform)
dform = self.compat.assemble_adjoint_value(dform)
if arity_form == 1 and dform != 0:
# Then dform is a Vector
dform = dform.function
return dform

def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
Expand All @@ -165,7 +172,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
if self.compat.isconstant(c1):
mesh = self.compat.extract_mesh_from_form(form)
space = c1._ad_function_space(mesh)
elif isinstance(c1, self.backend.Function):
elif isinstance(c1, (self.backend.Function, self.backend.Cofunction)):
space = c1.function_space()
elif isinstance(c1, self.compat.ExpressionType):
mesh = form.ufl_domain().ufl_cargo()
Expand All @@ -180,7 +187,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
hessian_input, arity_form, form, c1_rep, space
)

ddform = 0
ddform = 0.
for other_idx, bv in relevant_dependencies:
c2_rep = bv.saved_output
tlm_input = bv.tlm_value
Expand All @@ -196,10 +203,8 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,

if not isinstance(ddform, float):
ddform = ufl.algorithms.expand_derivatives(ddform)
if not ddform.empty():
hessian_outputs += self.compute_action_adjoint(
adj_input, arity_form, dform=ddform
)[0]
if not (isinstance(ddform, ufl.ZeroBaseForm) or (isinstance(ddform, ufl.Form) and ddform.empty())):
hessian_outputs += self.compute_action_adjoint(adj_input, arity_form, dform=ddform)[0]

if isinstance(c1, self.compat.ExpressionType):
return [(hessian_outputs, space)]
Expand Down
12 changes: 2 additions & 10 deletions firedrake/adjoint_utils/blocks/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def extract_bc_subvector(value, Vtarget, bc):
for idx in bc._indices:
r = r.sub(idx)
assert Vtarget == r.function_space()
return r.vector()
return r
compat.extract_bc_subvector = extract_bc_subvector

def extract_mesh_from_form(form):
Expand Down Expand Up @@ -115,15 +115,7 @@ def constant_function_firedrake_compat(value):
return value.dat.data
compat.constant_function_firedrake_compat = constant_function_firedrake_compat

def assemble_adjoint_value(*args, **kwargs):
"""A wrapper around Firedrake's assemble that returns a Vector
instead of a Function when assembling a 1-form."""
result = backend.assemble(*args, **kwargs)
if isinstance(result, backend.Function):
return result.vector()
else:
return result
compat.assemble_adjoint_value = assemble_adjoint_value
compat.assemble_adjoint_value = backend.assemble

def gather(vec):
return vec.gather()
Expand Down
7 changes: 4 additions & 3 deletions firedrake/adjoint_utils/blocks/dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
adj_output = None
for adj_input in adj_inputs:
if self.compat.isconstant(c):
adj_value = self.backend.Function(self.parent_space)
adj_value = self.backend.Function(self.parent_space.dual())
adj_input.apply(adj_value.vector())
if self.function_space != self.parent_space:
vec = self.compat.extract_bc_subvector(
Expand Down Expand Up @@ -77,13 +77,13 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
# you can even use the Function outside its domain.
# For now we will just assume the FunctionSpace is the same for
# the BC and the Function.
adj_value = self.backend.Function(self.parent_space)
adj_value = self.backend.Function(self.parent_space.dual())
adj_input.apply(adj_value.vector())
r = self.compat.extract_bc_subvector(
adj_value, c.function_space(), bc
)
elif isinstance(c, self.compat.Expression):
adj_value = self.backend.Function(self.parent_space)
adj_value = self.backend.Function(self.parent_space.dual())
adj_input.apply(adj_value.vector())
output = self.compat.extract_bc_subvector(
adj_value, self.collapsed_space, bc
Expand All @@ -93,6 +93,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
adj_output = r
else:
adj_output += r

return adj_output

def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
Expand Down
Loading
Loading