diff --git a/demos/linear-wave-equation/linear_wave_equation.py.rst b/demos/linear-wave-equation/linear_wave_equation.py.rst index e97c268284..9d8c8cf119 100644 --- a/demos/linear-wave-equation/linear_wave_equation.py.rst +++ b/demos/linear-wave-equation/linear_wave_equation.py.rst @@ -107,7 +107,7 @@ options at this point, we may either `lump` the mass, which reduces the inversion to a pointwise division:: if lump_mass: - p += interpolate(assemble(dt * inner(nabla_grad(v), nabla_grad(phi))*dx) / assemble(v*dx), V) + p.dat.data[:] += assemble(dt * inner(nabla_grad(v), nabla_grad(phi))*dx).dat.data_ro / assemble(v*dx).dat.data_ro In the mass lumped case, we must now ensure that the resulting solution for :math:`p` satisfies the boundary conditions:: diff --git a/docs/notebooks/11-extract-adjoint-solutions.ipynb b/docs/notebooks/11-extract-adjoint-solutions.ipynb index d4a0b7560a..667a3f8f55 100644 --- a/docs/notebooks/11-extract-adjoint-solutions.ipynb +++ b/docs/notebooks/11-extract-adjoint-solutions.ipynb @@ -1227,6 +1227,8 @@ " t = i*timesteps_per_export*dt\n", " tricontourf(forward_solutions[i], axes=axs[i, 0])\n", " adjoint_solution = dJdu if i == num_exports else solve_blocks[timesteps_per_export*i].adj_sol\n", + " # Get the Riesz representer\n", + " adjoint_solution = dJdu.riesz_representation(riesz_map=\"H1\")\n", " tricontourf(adjoint_solution, axes=axs[i, 1])\n", " axs[i, 0].annotate('t={:.2f}'.format(t), (0.05, 0.05), color='white');\n", " axs[i, 1].annotate('t={:.2f}'.format(t), (0.05, 0.05), color='white');\n", diff --git a/firedrake/__init__.py b/firedrake/__init__.py index 871235d6b7..6ac952e0ad 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -72,6 +72,7 @@ from firedrake.assemble import * from firedrake.bcs import * from firedrake.checkpointing import * +from firedrake.cofunction import * from firedrake.constant import * from firedrake.exceptions import * from firedrake.function import * diff --git a/firedrake/adjoint_utils/assembly.py b/firedrake/adjoint_utils/assembly.py index e596d3dc03..2cc4f0fc66 100644 --- a/firedrake/adjoint_utils/assembly.py +++ b/firedrake/adjoint_utils/assembly.py @@ -19,12 +19,14 @@ def wrapper(*args, **kwargs): output = assemble(*args, **kwargs) from firedrake.function import Function + from firedrake.cofunction import Cofunction form = args[0] - if isinstance(output, (numbers.Complex, Function)): + if isinstance(output, (numbers.Complex, Function, Cofunction)): + # Assembling a 0-form or 1-form (e.g. Form) if not annotate: return output - if not isinstance(output, (float, Function)): + if not isinstance(output, (float, Function, Cofunction)): raise NotImplementedError("Taping for complex-valued 0-forms not yet done!") output = create_overloaded_object(output) block = AssembleBlock(form, ad_block_tag=ad_block_tag) @@ -34,7 +36,7 @@ def wrapper(*args, **kwargs): block.add_output(output.block_variable) else: - # Assembled a matrix + # Assembled a 2-form output.form = form return output diff --git a/firedrake/adjoint_utils/blocks/assembly.py b/firedrake/adjoint_utils/blocks/assembly.py index bff5683393..730f81bcc6 100644 --- a/firedrake/adjoint_utils/blocks/assembly.py +++ b/firedrake/adjoint_utils/blocks/assembly.py @@ -1,7 +1,7 @@ import ufl import firedrake from ufl.formatting.ufl2unicode import ufl2unicode -from pyadjoint import Block, create_overloaded_object +from pyadjoint import Block, AdjFloat, create_overloaded_object from .backend import Backend from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint @@ -13,8 +13,11 @@ def __init__(self, form, ad_block_tag=None): if self.backend.__name__ != "firedrake": mesh = self.form.ufl_domain().ufl_cargo() else: - mesh = self.form.ufl_domain() - self.add_dependency(mesh) + mesh = self.form.ufl_domain() if hasattr(self.form, 'ufl_domain') else None + + if mesh: + self.add_dependency(mesh) + for c in self.form.coefficients(): self.add_dependency(c, no_duplicates=True) @@ -28,7 +31,7 @@ def compute_action_adjoint(self, adj_input, arity_form, form=None, `<(dform/dc_rep)*, adj_input>` - If `form` has arity 0 => `dform/dc_rep` is a 1-form and - `adj_input` a foat, we can simply use the `*` operator. + `adj_input` a float, we can simply use the `*` operator. - If `form` has arity 1 => `dform/dc_rep` is a 2-form and we can symbolically take its adjoint and then apply the action on @@ -38,31 +41,38 @@ def compute_action_adjoint(self, adj_input, arity_form, form=None, if dform is None: dc = self.backend.TestFunction(space) dform = self.backend.derivative(form, c_rep, dc) - dform_vector = self.compat.assemble_adjoint_value(dform) - # Return a Vector scaled by the scalar `adj_input` - return dform_vector * adj_input, dform + dform_adj = self.compat.assemble_adjoint_value(dform) + if dform_adj == 0: + # `dform_adj` is a `ZeroBaseForm` + return AdjFloat(0.), dform + # Return the adjoint model of `form` scaled by the scalar `adj_input` + adj_output = dform_adj._ad_mul(adj_input) + return adj_output, dform elif arity_form == 1: if dform is None: dc = self.backend.TrialFunction(space) dform = self.backend.derivative(form, c_rep, dc) - # Get the Function - adj_input = adj_input.function # Symbolic operators such as action/adjoint require derivatives to # have been expanded beforehand. However, UFL doesn't support # expanding coordinate derivatives of Coefficients in physical # space, implying that we can't symbolically take the - # action/adjoint of the Jacobian for SpatialCoordinates. -> - # Workaround: Apply action/adjoint numerically (using PETSc). + # action/adjoint of the Jacobian for SpatialCoordinates. + # -> Workaround: Apply action/adjoint numerically (using PETSc). if not isinstance(c_rep, self.backend.SpatialCoordinate): # Symbolically compute: (dform/dc_rep)^* * adj_input adj_output = self.backend.action(self.backend.adjoint(dform), adj_input) adj_output = self.compat.assemble_adjoint_value(adj_output) else: + adj_output = self.backend.Cofunction(space.dual()) + # Assemble `dform`: derivatives are expanded along the way + # which may lead to a ZeroBaseForm + assembled_dform = self.compat.assemble_adjoint_value(dform) + if assembled_dform == 0: + return adj_output, dform # Get PETSc matrix - dform_mat = self.compat.assemble_adjoint_value(dform).petscmat + dform_mat = assembled_dform.petscmat # Action of the adjoint (Hermitian transpose) - adj_output = self.backend.Function(space) with adj_input.dat.vec_ro as v_vec: with adj_output.dat.vec as res_vec: dform_mat.multHermitian(v_vec, res_vec) @@ -105,7 +115,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, if self.compat.isconstant(c): mesh = self.compat.extract_mesh_from_form(self.form) space = c._ad_function_space(mesh) - elif isinstance(c, self.backend.Function): + elif isinstance(c, (self.backend.Function, self.backend.Cofunction)): space = c.function_space() elif isinstance(c, self.compat.MeshType): c_rep = self.backend.SpatialCoordinate(c_rep) @@ -123,8 +133,6 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, form = prepared dform = 0. - from ufl.algorithms.analysis import extract_arguments - arity_form = len(extract_arguments(form)) for bv in self.get_dependencies(): c_rep = bv.saved_output tlm_value = bv.tlm_value @@ -133,15 +141,14 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, continue if isinstance(c_rep, self.compat.MeshType): X = self.backend.SpatialCoordinate(c_rep) + # Spatial coordinates derivatives cannot be expanded in the physical space, + # which is required by symbolic operators such as `action`. dform += self.backend.derivative(form, X, tlm_value) else: - dform += self.backend.derivative(form, c_rep, tlm_value) + dform += self.backend.action(self.backend.derivative(form, c_rep), tlm_value) if not isinstance(dform, float): dform = ufl.algorithms.expand_derivatives(dform) dform = self.compat.assemble_adjoint_value(dform) - if arity_form == 1 and dform != 0: - # Then dform is a Vector - dform = dform.function return dform def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, @@ -165,7 +172,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, if self.compat.isconstant(c1): mesh = self.compat.extract_mesh_from_form(form) space = c1._ad_function_space(mesh) - elif isinstance(c1, self.backend.Function): + elif isinstance(c1, (self.backend.Function, self.backend.Cofunction)): space = c1.function_space() elif isinstance(c1, self.compat.ExpressionType): mesh = form.ufl_domain().ufl_cargo() @@ -180,7 +187,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, hessian_input, arity_form, form, c1_rep, space ) - ddform = 0 + ddform = 0. for other_idx, bv in relevant_dependencies: c2_rep = bv.saved_output tlm_input = bv.tlm_value @@ -196,10 +203,8 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, if not isinstance(ddform, float): ddform = ufl.algorithms.expand_derivatives(ddform) - if not ddform.empty(): - hessian_outputs += self.compute_action_adjoint( - adj_input, arity_form, dform=ddform - )[0] + if not (isinstance(ddform, ufl.ZeroBaseForm) or (isinstance(ddform, ufl.Form) and ddform.empty())): + hessian_outputs += self.compute_action_adjoint(adj_input, arity_form, dform=ddform)[0] if isinstance(c1, self.compat.ExpressionType): return [(hessian_outputs, space)] diff --git a/firedrake/adjoint_utils/blocks/compat.py b/firedrake/adjoint_utils/blocks/compat.py index 8196419a11..76f7d3c830 100644 --- a/firedrake/adjoint_utils/blocks/compat.py +++ b/firedrake/adjoint_utils/blocks/compat.py @@ -77,7 +77,7 @@ def extract_bc_subvector(value, Vtarget, bc): for idx in bc._indices: r = r.sub(idx) assert Vtarget == r.function_space() - return r.vector() + return r compat.extract_bc_subvector = extract_bc_subvector def extract_mesh_from_form(form): @@ -115,15 +115,7 @@ def constant_function_firedrake_compat(value): return value.dat.data compat.constant_function_firedrake_compat = constant_function_firedrake_compat - def assemble_adjoint_value(*args, **kwargs): - """A wrapper around Firedrake's assemble that returns a Vector - instead of a Function when assembling a 1-form.""" - result = backend.assemble(*args, **kwargs) - if isinstance(result, backend.Function): - return result.vector() - else: - return result - compat.assemble_adjoint_value = assemble_adjoint_value + compat.assemble_adjoint_value = backend.assemble def gather(vec): return vec.gather() diff --git a/firedrake/adjoint_utils/blocks/dirichlet_bc.py b/firedrake/adjoint_utils/blocks/dirichlet_bc.py index 3cc1cfc957..c913dbf307 100644 --- a/firedrake/adjoint_utils/blocks/dirichlet_bc.py +++ b/firedrake/adjoint_utils/blocks/dirichlet_bc.py @@ -39,18 +39,16 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, adj_output = None for adj_input in adj_inputs: if self.compat.isconstant(c): - adj_value = self.backend.Function(self.parent_space) - adj_input.apply(adj_value.vector()) + adj_value = self.backend.Function(self.parent_space.dual()) + adj_input.apply(adj_value) if self.function_space != self.parent_space: vec = self.compat.extract_bc_subvector( adj_value, self.collapsed_space, bc ) - adj_value = self.compat.function_from_vector( - self.collapsed_space, vec - ) + adj_value = self.backend.Function(self.collapsed_space, vec.dat) if adj_value.ufl_shape == () or adj_value.ufl_shape[0] <= 1: - r = adj_value.vector().sum() + r = adj_value.dat.data_ro.sum() else: output = [] subindices = _extract_subindices(self.function_space) @@ -64,7 +62,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, output.append( current_subfunc.sub( prev_idx, deepcopy=True - ).vector().sum() + ).dat.data_ro.sum() ) r = self.backend.cpp.la.Vector(self.backend.MPI.comm_world, @@ -77,14 +75,14 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, # you can even use the Function outside its domain. # For now we will just assume the FunctionSpace is the same for # the BC and the Function. - adj_value = self.backend.Function(self.parent_space) - adj_input.apply(adj_value.vector()) + adj_value = self.backend.Function(self.parent_space.dual()) + adj_input.apply(adj_value) r = self.compat.extract_bc_subvector( adj_value, c.function_space(), bc ) elif isinstance(c, self.compat.Expression): - adj_value = self.backend.Function(self.parent_space) - adj_input.apply(adj_value.vector()) + adj_value = self.backend.Function(self.parent_space.dual()) + adj_input.apply(adj_value) output = self.compat.extract_bc_subvector( adj_value, self.collapsed_space, bc ) @@ -93,6 +91,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, adj_output = r else: adj_output += r + return adj_output def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, diff --git a/firedrake/adjoint_utils/blocks/function.py b/firedrake/adjoint_utils/blocks/function.py index 49cb86cf55..8368ffb3f7 100644 --- a/firedrake/adjoint_utils/blocks/function.py +++ b/firedrake/adjoint_utils/blocks/function.py @@ -2,7 +2,7 @@ from ufl import replace from ufl.corealg.traversal import traverse_unique_terminals from ufl.formatting.ufl2unicode import ufl2unicode -from ufl.algorithms.analysis import extract_arguments_and_coefficients +from ufl.algorithms.analysis import extract_arguments, extract_arguments_and_coefficients from pyadjoint import Block, OverloadedType, AdjFloat import firedrake from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint, \ @@ -38,8 +38,9 @@ def _replace_with_saved_output(self): return ufl.replace(self.expr, replace_map) def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): - V = self.get_outputs()[0].output.function_space() - adj_input_func = self.compat.function_from_vector(V, adj_inputs[0]) + adj_input_func, = adj_inputs + if isinstance(adj_input_func, self.backend.Cofunction): + adj_input_func = adj_input_func.riesz_representation(riesz_map="l2") if self.expr is None: return adj_input_func @@ -53,7 +54,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, if isinstance(block_variable.output, AdjFloat): try: # Adjoint of a broadcast is just a sum - return adj_inputs[0].sum() + return adj_inputs[0].dat.data_ro.sum() except AttributeError: # Catch the case where adj_inputs[0] is just a float return adj_inputs[0] @@ -66,7 +67,8 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, adj_output = self.backend.Function( block_variable.output.function_space()) adj_output.assign(prepared) - return adj_output.vector() + adj_output = adj_output.riesz_representation(riesz_map="l2") + return adj_output else: # Linear combination expr, adj_input_func = prepared @@ -77,7 +79,9 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, expr, block_variable.saved_output, adj_input_func ) ) - adj_output.assign(diff_expr) + # Firedrake does not support assignment of conjugate functions + adj_output.interpolate(ufl.conj(diff_expr)) + adj_output = adj_output.riesz_representation(riesz_map="l2") else: mesh = adj_output.function_space().mesh() diff_expr = ufl.algorithms.expand_derivatives( @@ -88,7 +92,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, ) ) adj_output.assign(diff_expr) - return adj_output.vector().inner(adj_input_func.vector()) + return adj_output.dat.inner(adj_input_func.dat) if self.compat.isconstant(block_variable.output): R = block_variable.output._ad_function_space( @@ -96,23 +100,23 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, ) return self._adj_assign_constant(adj_output, R) else: - return adj_output.vector() + return adj_output def _adj_assign_constant(self, adj_output, constant_fs): r = self.backend.Function(constant_fs) shape = r.ufl_shape if shape == () or shape[0] == 1: # Scalar Constant - r.vector()[:] = adj_output.vector().sum() + r.dat.data[:] = adj_output.dat.data_ro.sum() else: # We assume the shape of the constant == shape of the output # function if not scalar. This assumption is due to FEniCS not # supporting products with non-scalar constants in assign. values = [] for i in range(shape[0]): - values.append(adj_output.sub(i, deepcopy=True).vector().sum()) + values.append(adj_output.sub(i, deepcopy=True).dat.data_ro.sum()) r.assign(self.backend.Constant(values)) - return r.vector() + return r def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs): if self.expr is None: @@ -133,7 +137,7 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, dudmi.assign(ufl.algorithms.expand_derivatives( ufl.derivative(expr, dep.saved_output, dep.tlm_value))) - dudm.vector().axpy(1.0, dudmi.vector()) + dudm.dat += 1.0 * dudmi.dat return dudm @@ -322,10 +326,13 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, raise NotImplementedError( "Interpolate block must have a single output" ) - dJdm = self.backend.derivative(prepared, inputs[idx]) - return self.backend.Interpolator(dJdm, self.V).interpolate( - adj_inputs[0], transpose=True - ).vector() + input = inputs[idx] + dJdm = self.backend.derivative(prepared, input) + # Get the function space from `dJdm` argument + arg, = extract_arguments(dJdm) + # Make sure to have a cofunction output + output = self.backend.Cofunction(arg.function_space().dual()) + return self.backend.Interpolator(dJdm, self.V).interpolate(adj_inputs[0], output=output, transpose=True) def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs): return replace(self.expr, self._replace_map()) @@ -505,12 +512,14 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, d2exprdudu = self.backend.derivative(dexprdu, block_variable.saved_output) + # Make sure to have a cofunction output + output = self.backend.Cofunction(component.function_space()) # left multiply by dJ/dv (adj_inputs[0]) - i.e. interpolate using the # transpose operator component += self.backend.Interpolator(d2exprdudu, self.V).interpolate( - adj_inputs[0], transpose=True + adj_inputs[0], output=output, transpose=True ) - return component.vector() + return component def prepare_recompute_component(self, inputs, relevant_outputs): return replace(self.expr, self._replace_map()) @@ -535,14 +544,12 @@ def __init__(self, func, idx, ad_block_tag=None): def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): - eval_adj = self.backend.Function( - block_variable.output.function_space() - ) - if type(adj_inputs[0]) is self.backend.Function: + eval_adj = self.backend.Cofunction(block_variable.output.function_space().dual()) + if type(adj_inputs[0]) is self.backend.Cofunction: eval_adj.sub(self.idx).assign(adj_inputs[0]) else: eval_adj.sub(self.idx).assign(adj_inputs[0].function) - return eval_adj.vector() + return eval_adj def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): @@ -551,11 +558,9 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): - eval_hessian = self.backend.Function( - block_variable.output.function_space() - ) - eval_hessian.sub(self.idx).assign(hessian_inputs[0].function) - return eval_hessian.vector() + eval_hessian = self.backend.Cofunction(block_variable.output.function_space().dual()) + eval_hessian.sub(self.idx).assign(hessian_inputs[0]) + return eval_hessian def recompute_component(self, inputs, block_variable, idx, prepared): return maybe_disk_checkpoint( @@ -577,7 +582,7 @@ def __init__(self, func, idx, ad_block_tag=None): def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): if idx == 0: - return adj_inputs[0].subfunctions[self.idx].vector() + return adj_inputs[0].subfunctions[self.idx] else: return adj_inputs[0] diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 9665f309b0..ed84c39ad9 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -41,7 +41,7 @@ def __init__(self, lhs, rhs, func, bcs, *args, **kwargs): if bcs is not None: self.bcs = Enlist(bcs) - if isinstance(self.lhs, ufl.Form) and isinstance(self.rhs, ufl.Form): + if isinstance(self.lhs, ufl.Form) and isinstance(self.rhs, (ufl.Form, ufl.Cofunction)): self.linear = True for c in self.rhs.coefficients(): self.add_dependency(c, no_duplicates=True) @@ -184,16 +184,14 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy): adj_sol = self.compat.create_function(self.function_space) self.compat.linalg_solve( - dFdu, adj_sol.vector(), dJdu, *self.adj_args, **self.adj_kwargs + dFdu, adj_sol, dJdu, *self.adj_args, **self.adj_kwargs ) adj_sol_bdy = None if compute_bdy: - adj_sol_bdy = self.compat.function_from_vector( - self.function_space, - dJdu_copy - self.compat.assemble_adjoint_value( - self.backend.action(dFdu_adj_form, adj_sol) - ) + adj_sol_bdy = self.backend.Function( + self.function_space.dual(), + dJdu_copy.dat - self.compat.assemble_adjoint_value(self.backend.action(dFdu_adj_form, adj_sol)).dat ) return adj_sol, adj_sol_bdy @@ -214,7 +212,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, trial_function = self.backend.TrialFunction( c._ad_function_space(mesh) ) - elif isinstance(c, self.backend.Function): + elif isinstance(c, (self.backend.Function, self.backend.Cofunction)): trial_function = self.backend.TrialFunction(c.function_space()) elif isinstance(c, self.compat.ExpressionType): mesh = F_form.ufl_domain().ufl_cargo() @@ -228,7 +226,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, ) return [tmp_bc] elif isinstance(c, self.compat.MeshType): - # Using CoordianteDerivative requires us to do action before + # Using CoordinateDerivative requires us to do action before # differentiating, might change in the future. F_form_tmp = self.backend.action(F_form, adj_sol) X = self.backend.SpatialCoordinate(c_rep) @@ -237,6 +235,9 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, self.backend.TestFunction(c._ad_function_space()) ) + if dFdm == 0: + return self.backend.Function(c._ad_function_space().dual()) + dFdm = self.compat.assemble_adjoint_value(dFdm, **self.assemble_kwargs) return dFdm @@ -358,8 +359,7 @@ def _assemble_and_solve_soa_eq(self, dFdu_form, adj_sol, hessian_input, b = self._assemble_soa_eq_rhs(dFdu_form, adj_sol, hessian_input, d2Fdu2) dFdu_form = self.backend.adjoint(dFdu_form) - adj_sol2, adj_sol2_bdy = self._assemble_and_solve_adj_eq(dFdu_form, b, - compute_bdy) + adj_sol2, adj_sol2_bdy = self._assemble_and_solve_adj_eq(dFdu_form, b, compute_bdy) if self.adj2_cb is not None: self.adj2_cb(adj_sol2) if self.adj2_bdy_cb is not None and compute_bdy: @@ -490,7 +490,8 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, ) hessian_output = 0 if not hessian_form.empty(): - hessian_output -= self.compat.assemble_adjoint_value(hessian_form) + hessian_output = self.compat.assemble_adjoint_value(hessian_form) + hessian_output *= -1. if isinstance(c, self.compat.ExpressionType): return [(hessian_output, W)] @@ -517,9 +518,11 @@ def _forward_solve(self, lhs, rhs, func, bcs): return func def _assembled_solve(self, lhs, rhs, func, bcs, **kwargs): + rhs_func = rhs.riesz_representation(riesz_map="l2") for bc in bcs: - bc.apply(rhs) - self.backend.solve(lhs, func.vector(), rhs, **kwargs) + bc.apply(rhs_func) + rhs.assign(rhs_func.riesz_representation(riesz_map="l2")) + self.backend.solve(lhs, func, rhs, **kwargs) return func def recompute_component(self, inputs, block_variable, idx, prepared): @@ -832,12 +835,11 @@ def recompute_component(self, inputs, block_variable, idx, prepared): return maybe_disk_checkpoint(target) def _recompute_component_transpose(self, inputs): - if not isinstance(inputs[0], - (self.backend.Function, self.backend.Vector)): + if not isinstance(inputs[0], self.backend.Cofunction): raise NotImplementedError( - f"Source function must be a Function, not {type(inputs[0])}." + f"Source function must be a Cofunction, not {type(inputs[0])}." ) - out = self.backend.Function(self.source_space) + out = self.backend.Cofunction(self.source_space.dual()) tmp = self.backend.Function(self.target_space) # Adjoint of step 2 (mass is self-adjoint) self.projector.apply_massinv(tmp, inputs[0]) @@ -860,7 +862,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, raise NotImplementedError( "SupermeshProjectBlock must have a single output" ) - return self._recompute_component_transpose(adj_inputs).vector() + return self._recompute_component_transpose(adj_inputs) def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): @@ -885,7 +887,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, "SupermeshProjectBlock must have a single output" ) return self.evaluate_adj_component(inputs, hessian_inputs, - block_variable, idx).vector() + block_variable, idx) def __str__(self): target_string = f"〈{str(self.target_space.ufl_element().shortstr())}〉" diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 67bccaba66..208e70cf91 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -129,6 +129,15 @@ def wrapper(self, other, *args, **kwargs): return wrapper + @staticmethod + def _ad_not_implemented(func): + @wraps(func) + def wrapper(*args, **kwargs): + if annotate_tape(kwargs): + raise NotImplementedError("Automatic differentiation is not supported for this operation.") + return func(*args, **kwargs) + return wrapper + @staticmethod def _ad_annotate_iadd(__iadd__): @wraps(__iadd__) @@ -211,31 +220,27 @@ def _ad_create_checkpoint(self): else: return self.copy(deepcopy=True) - @no_annotations - def _ad_convert_type(self, value, options=None): - from firedrake import Function, TrialFunction, TestFunction, assemble + def _ad_convert_riesz(self, value, options=None): + from firedrake import Function, Cofunction options = {} if options is None else options riesz_representation = options.get("riesz_representation", "l2") + solver_options = options.get("solver_options", {}) + V = options.get("function_space", self.function_space()) - if riesz_representation == "l2": - return Function(self.function_space(), val=value) + if riesz_representation != "l2" and not isinstance(value, Cofunction): + raise TypeError("Expected a Cofunction") + elif not isinstance(value, (float, (Cofunction, Function))): + raise TypeError("Expected a Cofunction, Function or a float") - elif riesz_representation == "L2": - ret = Function(self.function_space()) - u = TrialFunction(self.function_space()) - v = TestFunction(self.function_space()) - M = assemble(firedrake.inner(u, v)*firedrake.dx) - firedrake.solve(M, ret, value) - return ret + if riesz_representation == "l2": + value = value.dat if isinstance(value, (Cofunction, Function)) else value + return Function(V, val=value) - elif riesz_representation == "H1": - ret = Function(self.function_space()) - u = TrialFunction(self.function_space()) - v = TestFunction(self.function_space()) - M = assemble(firedrake.inner(u, v)*firedrake.dx - + firedrake.inner(firedrake.grad(u), firedrake.grad(v))*firedrake.dx) - firedrake.solve(M, ret, value) + elif riesz_representation in ("L2", "H1"): + ret = Function(V) + a = self._define_riesz_map_form(riesz_representation, V) + firedrake.solve(a == value, ret, **solver_options) return ret elif callable(riesz_representation): @@ -245,6 +250,28 @@ def _ad_convert_type(self, value, options=None): raise NotImplementedError( "Unknown Riesz representation %s" % riesz_representation) + def _define_riesz_map_form(self, riesz_representation, V): + from firedrake import TrialFunction, TestFunction + + u = TrialFunction(V) + v = TestFunction(V) + if riesz_representation == "L2": + a = firedrake.inner(u, v)*firedrake.dx + + elif riesz_representation == "H1": + a = firedrake.inner(u, v)*firedrake.dx \ + + firedrake.inner(firedrake.grad(u), firedrake.grad(v))*firedrake.dx + + else: + raise NotImplementedError( + "Unknown Riesz representation %s" % riesz_representation) + return a + + @no_annotations + def _ad_convert_type(self, value, options=None): + # `_ad_convert_type` is not annoated unlike to `_ad_convert_riesz` + return self._ad_convert_riesz(value, options=options) + def _ad_restore_at_checkpoint(self, checkpoint): if isinstance(checkpoint, CheckpointBase): return checkpoint.restore() @@ -263,7 +290,8 @@ def _ad_mul(self, other): from firedrake import Function r = Function(self.function_space()) - r.assign(self * other) + # `self` can be a Cofunction in which case only left multiplication with a scalar is allowed. + r.assign(other * self) return r @no_annotations @@ -280,7 +308,7 @@ def _ad_dot(self, other, options=None): options = {} if options is None else options riesz_representation = options.get("riesz_representation", "l2") if riesz_representation == "l2": - return self.vector().inner(other.vector()) + return self.dat.inner(other.dat) elif riesz_representation == "L2": return assemble(firedrake.inner(self, other)*firedrake.dx) elif riesz_representation == "H1": diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index f1e800ad45..ff16ce6b25 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -24,7 +24,7 @@ def wrapper(self, *args, **kwargs): self.u, TrialFunction(self.u.function_space())) self._ad_adj_F = adjoint(dFdu) - except TypeError: + except (TypeError, NotImplementedError): self._ad_adj_F = None self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear} self._ad_count_map = {} diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 65a188c83b..124d1f29d0 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -14,12 +14,12 @@ from tsfc.finatinterface import create_element from tsfc.ufl_utils import extract_firedrake_constants import ufl -from ufl.domain import extract_unique_domain from firedrake import (extrusion_utils as eutils, matrix, parameters, solving, tsfc_interface, utils) from firedrake.adjoint_utils import annotate_assemble +from firedrake.ufl_expr import extract_unique_domain from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit -from firedrake.functionspaceimpl import WithGeometry, FunctionSpace +from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key from firedrake.petsc import PETSc from firedrake.slate import slac, slate @@ -42,7 +42,7 @@ def assemble(expr, *args, **kwargs): r"""Evaluate expr. - :arg expr: a :class:`~ufl.classes.Form`, :class:`~ufl.classes.Expr` or + :arg expr: a :class:`~ufl.classes.BaseForm`, :class:`~ufl.classes.Expr` or a :class:`~.slate.TensorBase` expression. :arg tensor: Existing tensor object to place the result in. :arg bcs: Iterable of boundary conditions to apply. @@ -98,20 +98,266 @@ def assemble(expr, *args, **kwargs): will be set to 0 and the diagonal entries to 1. If ``expr`` is a 1-form, the vector entries at boundary nodes are set to the boundary condition values. - - .. note:: - For 1-form assembly, the resulting object should in fact be a *cofunction* - instead of a :class:`.Function`. However, since cofunctions are not - currently supported in UFL, functions are used instead. """ - if isinstance(expr, (ufl.form.Form, slate.TensorBase)): - return _assemble_form(expr, *args, **kwargs) + if isinstance(expr, (ufl.form.BaseForm, slate.TensorBase)): + return assemble_base_form(expr, *args, **kwargs) elif isinstance(expr, ufl.core.expr.Expr): return _assemble_expr(expr) else: raise TypeError(f"Unable to assemble: {expr}") +def assemble_base_form(expression, tensor=None, bcs=None, + diagonal=False, + mat_type=None, + sub_mat_type=None, + form_compiler_parameters=None, + appctx=None, + options_prefix=None, + zero_bc_nodes=False, + is_base_form_preprocessed=False, + weight=1.0): + r"""Evaluate expression. + + :arg expression: a :class:`~ufl.classes.BaseForm` + :kwarg tensor: Existing tensor object to place the result in. + :kwarg bcs: Iterable of boundary conditions to apply. + :kwarg diagonal: If assembling a matrix is it diagonal? + :kwarg mat_type: String indicating how a 2-form (matrix) should be + assembled -- either as a monolithic matrix (``"aij"`` or ``"baij"``), + a block matrix (``"nest"``), or left as a :class:`.ImplicitMatrix` giving + matrix-free actions (``'matfree'``). If not supplied, the default value in + ``parameters["default_matrix_type"]`` is used. BAIJ differs + from AIJ in that only the block sparsity rather than the dof + sparsity is constructed. This can result in some memory + savings, but does not work with all PETSc preconditioners. + BAIJ matrices only make sense for non-mixed matrices. + :kwarg sub_mat_type: String indicating the matrix type to + use *inside* a nested block matrix. Only makes sense if + ``mat_type`` is ``nest``. May be one of ``"aij"`` or ``"baij"``. If + not supplied, defaults to ``parameters["default_sub_matrix_type"]``. + :kwarg form_compiler_parameters: Dictionary of parameters to pass to + the form compiler. Ignored if not assembling a :class:`~ufl.classes.Form`. + Any parameters provided here will be overridden by parameters set on the + :class:`~ufl.classes.Measure` in the form. For example, if a + ``quadrature_degree`` of 4 is specified in this argument, but a degree of + 3 is requested in the measure, the latter will be used. + :kwarg appctx: Additional information to hang on the assembled + matrix if an implicit matrix is requested (mat_type ``"matfree"``). + :kwarg options_prefix: PETSc options prefix to apply to matrices. + :kwarg zero_bc_nodes: If ``True``, set the boundary condition nodes in the + output tensor to zero rather than to the values prescribed by the + boundary condition. Default is ``False``. + :kwarg is_base_form_preprocessed: If ``True``, skip preprocessing of the form. + :kwarg weight: weight of the boundary condition, i.e. the scalar in front of the + identity matrix corresponding to the boundary nodes. + To discretise eigenvalue problems set the weight equal to 0.0. + + :returns: a :class:`float` for 0-forms, a :class:`.Cofunction` or a :class:`.Function` for 1-forms, + and a :class:`.MatrixBase` for 2-forms. + + This function assembles a :class:`~ufl.classes.BaseForm` object by traversing the corresponding DAG + in a post-order fashion and evaluating the nodes on the fly. + """ + + # Preprocess the DAG and restructure the DAG + if not is_base_form_preprocessed and not isinstance(expression, slate.TensorBase): + # Preprocessing the form makes a new object -> current form caching mechanism + # will populate `expr`'s cache which is now different than `expression`'s cache so we need + # to transmit the cache. All of this only holds when `expression` is a `ufl.Form` + # and therefore when `is_base_form_preprocessed` is False. + expr = preprocess_base_form(expression, mat_type, form_compiler_parameters) + if isinstance(expression, ufl.form.Form) and isinstance(expr, ufl.form.Form): + expr._cache = expression._cache + else: + expr = expression + + # DAG assembly: traverse the DAG in a post-order fashion and evaluate the node on the fly. + stack = [expr] + visited = {} + while stack: + e = stack.pop() + unvisted_children = [] + operands = base_form_operands(e) + for arg in operands: + if arg not in visited: + unvisted_children.append(arg) + + if unvisted_children: + stack.append(e) + stack.extend(unvisted_children) + else: + visited[e] = base_form_assembly_visitor(e, tensor, bcs, diagonal, + form_compiler_parameters, + mat_type, sub_mat_type, + appctx, options_prefix, + zero_bc_nodes, weight, + *(visited[arg] for arg in operands)) + if tensor: + update_tensor(visited[expr], tensor) + return visited[expr] + + +def update_tensor(assembled_base_form, tensor): + if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)): + assembled_base_form.dat.copy(tensor.dat) + elif isinstance(tensor, matrix.MatrixBase): + # Uses the PETSc copy method. + assembled_base_form.petscmat.copy(tensor.petscmat) + else: + raise NotImplementedError("Cannot update tensor of type %s" % type(tensor)) + + +def base_form_operands(expr): + if isinstance(expr, (ufl.form.FormSum, ufl.Adjoint, ufl.Action)): + return list(expr.ufl_operands) + return [] + + +def expand_derivatives_form(form, fc_params): + """Expand derivatives of ufl.BaseForm objects + :arg form: a :class:`~ufl.classes.BaseForm` + :arg fc_params:: Dictionary of parameters to pass to the form compiler. + + :returns: The resulting preprocessed :class:`~ufl.classes.BaseForm`. + This function preprocess the form, mainly by expanding the derivatives, in order to determine + if we are dealing with a :class:`~ufl.classes.Form` or another :class:`~ufl.classes.BaseForm` object. + This function is called in :func:`base_form_assembly_visitor`. Depending on the type of the resulting tensor, + we may call :func:`assemble_form` or traverse the sub-DAG via :func:`assemble_base_form`. + """ + if isinstance(form, ufl.form.Form): + from firedrake.parameters import parameters as default_parameters + from tsfc.parameters import is_complex + + if fc_params is None: + fc_params = default_parameters["form_compiler"].copy() + else: + # Override defaults with user-specified values + _ = fc_params + fc_params = default_parameters["form_compiler"].copy() + fc_params.update(_) + + complex_mode = fc_params and is_complex(fc_params.get("scalar_type")) + + return ufl.algorithms.preprocess_form(form, complex_mode) + # We also need to expand derivatives for `ufl.BaseForm` objects that are not `ufl.Form` + # Example: `Action(A, derivative(B, f))`, where `A` is a `ufl.BaseForm` and `B` can + # be `ufl.BaseForm`, or even an appropriate `ufl.Expr`, since assembly of expressions + # containing derivatives is not supported anymore but might be needed if the expression + # in question is within a `ufl.BaseForm` object. + return ufl.algorithms.ad.expand_derivatives(form) + + +def preprocess_base_form(expr, mat_type=None, form_compiler_parameters=None): + """Preprocess ufl.BaseForm objects""" + if mat_type != "matfree": + # For "matfree", Form evaluation is delayed + expr = expand_derivatives_form(expr, form_compiler_parameters) + # Expanding derivatives may turn `ufl.BaseForm` objects into `ufl.Expr` objects that are not `ufl.BaseForm`. + if not isinstance(expr, ufl.form.BaseForm): + return assemble(expr) + return expr + + +def base_form_assembly_visitor(expr, tensor, bcs, diagonal, + form_compiler_parameters, + mat_type, sub_mat_type, + appctx, options_prefix, + zero_bc_nodes, weight, *args): + r"""Assemble a :class:`~ufl.classes.BaseForm` object given its assembled operands. + + This functions contains the assembly handlers corresponding to the different nodes that + can arise in a `~ufl.classes.BaseForm` object. It is called by :func:`assemble_base_form` + in a post-order fashion. + """ + + if isinstance(expr, (ufl.form.Form, slate.TensorBase)): + return _assemble_form(expr, tensor=tensor, bcs=bcs, + diagonal=diagonal, + mat_type=mat_type, + sub_mat_type=sub_mat_type, + appctx=appctx, + options_prefix=options_prefix, + form_compiler_parameters=form_compiler_parameters, + zero_bc_nodes=zero_bc_nodes, weight=weight) + + elif isinstance(expr, ufl.Adjoint): + if len(args) != 1: + raise TypeError("Not enough operands for Adjoint") + mat, = args + petsc_mat = mat.petscmat + petsc_mat.hermitianTranspose() + (row, col) = mat.arguments() + return matrix.AssembledMatrix((col, row), bcs, petsc_mat, + appctx=appctx, + options_prefix=options_prefix) + elif isinstance(expr, ufl.Action): + if (len(args) != 2): + raise TypeError("Not enough operands for Action") + lhs, rhs = args + if isinstance(lhs, matrix.MatrixBase): + if isinstance(rhs, (firedrake.Cofunction, firedrake.Function)): + petsc_mat = lhs.petscmat + (row, col) = lhs.arguments() + res = firedrake.Cofunction(col.function_space().dual()) + + with rhs.dat.vec_ro as v_vec: + with res.dat.vec as res_vec: + petsc_mat.mult(v_vec, res_vec) + return firedrake.Cofunction(row.function_space().dual(), val=res.dat) + elif isinstance(rhs, matrix.MatrixBase): + petsc_mat = lhs.petscmat + (row, col) = lhs.arguments() + res = petsc_mat.matMult(rhs.petscmat) + return matrix.AssembledMatrix(expr, bcs, res, + appctx=appctx, + options_prefix=options_prefix) + else: + raise TypeError("Incompatible RHS for Action.") + elif isinstance(lhs, (firedrake.Cofunction, firedrake.Function)): + if isinstance(rhs, (firedrake.Cofunction, firedrake.Function)): + # Return scalar value + with lhs.dat.vec_ro as x, rhs.dat.vec_ro as y: + res = x.dot(y) + return res + else: + raise TypeError("Incompatible RHS for Action.") + else: + raise TypeError("Incompatible LHS for Action.") + elif isinstance(expr, ufl.FormSum): + if len(args) != len(expr.weights()): + raise TypeError("Mismatching weights and operands in FormSum") + if len(args) == 0: + raise TypeError("Empty FormSum") + if all(isinstance(op, float) for op in args): + return sum(args) + elif all(isinstance(op, firedrake.Cofunction) for op in args): + V, = set(a.function_space() for a in args) + res = sum([w*op.dat for (op, w) in zip(args, expr.weights())]) + return firedrake.Cofunction(V, res) + elif all(isinstance(op, ufl.Matrix) for op in args): + is_set = False + for (op, w) in zip(args, expr.weights()): + petsc_mat = op.petscmat + petsc_mat.scale(w) + if is_set: + res = res + petsc_mat + else: + res = petsc_mat + is_set = True + return matrix.AssembledMatrix(expr, bcs, res, + appctx=appctx, + options_prefix=options_prefix) + else: + raise TypeError("Mismatching FormSum shapes") + elif isinstance(expr, (ufl.Cofunction, ufl.Coargument, ufl.Argument, ufl.Matrix, ufl.ZeroBaseForm)): + return expr + elif isinstance(expr, ufl.Coefficient): + return expr + else: + raise TypeError(f"Unrecognised BaseForm instance: {expr}") + + @PETSc.Log.EventDecorator() def allocate_matrix(expr, bcs=None, *, mat_type=None, sub_mat_type=None, appctx=None, form_compiler_parameters=None, options_prefix=None): @@ -372,10 +618,10 @@ def _make_tensor(form, bcs, *, diagonal, mat_type, sub_mat_type, appctx, ) elif rank == 1: test, = form.arguments() - return firedrake.Function(test.function_space()) + return firedrake.Cofunction(test.function_space().dual()) elif rank == 2 and diagonal: test, _ = form.arguments() - return firedrake.Function(test.function_space()) + return firedrake.Cofunction(test.function_space().dual()) elif rank == 2: mat_type, sub_mat_type = _get_mat_type(mat_type, sub_mat_type, form.arguments()) return allocate_matrix(form, bcs, mat_type=mat_type, sub_mat_type=sub_mat_type, @@ -527,7 +773,7 @@ def collect_lgmaps(self, local_knl, bcs): def _as_pyop2_type(tensor): if isinstance(tensor, op2.Global): return tensor - elif isinstance(tensor, firedrake.Function): + elif isinstance(tensor, firedrake.Cofunction): return tensor.dat elif isinstance(tensor, matrix.Matrix): return tensor.M @@ -595,10 +841,12 @@ def _apply_bc(self, bc): def _apply_dirichlet_bc(self, bc): if not self._zero_bc_nodes: + tensor_func = self._tensor.riesz_representation(riesz_map="l2") if self._diagonal: - bc.set(self._tensor, 1) + bc.set(tensor_func, 1) else: - bc.apply(self._tensor) + bc.apply(tensor_func) + self._tensor.assign(tensor_func.riesz_representation(riesz_map="l2")) else: bc.zero(self._tensor) @@ -732,6 +980,31 @@ def assemble(self): return self._tensor +def get_form_assembler(form, tensor, *args, **kwargs): + """Provide the assemble method for `form`""" + + # Don't expand derivatives if `mat_type` is 'matfree' + mat_type = kwargs.pop('mat_type', None) + if not isinstance(form, slate.TensorBase): + fc_params = kwargs.get('form_compiler_parameters') + # Only pre-process `form` once beforehand to avoid pre-processing for each assembly call + form = preprocess_base_form(form, mat_type=mat_type, form_compiler_parameters=fc_params) + + if isinstance(form, (ufl.form.Form, slate.TensorBase)) and not base_form_operands(form): + diagonal = kwargs.pop('diagonal', False) + if len(form.arguments()) == 1 or diagonal: + return OneFormAssembler(form, tensor, *args, diagonal=diagonal, **kwargs).assemble + elif len(form.arguments()) == 2: + return TwoFormAssembler(form, tensor, *args, **kwargs).assemble + else: + raise ValueError('Expecting a 1-form or 2-form and not %s' % (form)) + elif isinstance(form, ufl.form.BaseForm): + return functools.partial(assemble_base_form, form, *args, tensor=tensor, mat_type=mat_type, + is_base_form_preprocessed=True, **kwargs) + else: + raise ValueError('Expecting a BaseForm or a slate.TensorBase object and not %s' % form) + + def _global_kernel_cache_key(form, local_knl, subdomain_id, all_integer_subdomain_ids, **kwargs): # N.B. Generating the global kernel is not a collective operation so the # communicator does not need to be a part of this cache key. @@ -1135,7 +1408,7 @@ def _iterset(self): def _get_map(self, V): """Return the appropriate PyOP2 map for a given function space.""" - assert isinstance(V, (WithGeometry, FunctionSpace)) + assert isinstance(V, (WithGeometry, FiredrakeDualSpace, FunctionSpace)) if self._integral_type in {"cell", "exterior_facet_top", "exterior_facet_bottom", "interior_facet_horiz"}: diff --git a/firedrake/cofunction.py b/firedrake/cofunction.py new file mode 100644 index 0000000000..56882b7379 --- /dev/null +++ b/firedrake/cofunction.py @@ -0,0 +1,302 @@ +import numpy as np +import ufl +from ufl.form import BaseForm +from pyop2 import op2, mpi +import firedrake.assemble +import firedrake.functionspaceimpl as functionspaceimpl +from firedrake import utils, vector, ufl_expr +from firedrake.utils import ScalarType +from firedrake.adjoint_utils.function import FunctionMixin + + +class Cofunction(ufl.Cofunction, FunctionMixin): + r"""A :class:`Cofunction` represents a function on a dual space. + Like Functions, cofunctions are + represented as sums of basis functions: + + .. math:: + + f = \\sum_i f_i \phi_i(x) + + The :class:`Cofunction` class provides storage for the coefficients + :math:`f_i` and associates them with a :class:`.FunctionSpace` object + which provides the basis functions :math:`\\phi_i(x)`. + + Note that the coefficients are always scalars: if the + :class:`Cofunction` is vector-valued then this is specified in + the :class:`.FunctionSpace`. + """ + + @FunctionMixin._ad_annotate_init + def __init__(self, function_space, val=None, name=None, dtype=ScalarType): + r""" + :param function_space: the :class:`.FunctionSpace`, + or :class:`.MixedFunctionSpace` on which to build this :class:`Cofunction`. + Alternatively, another :class:`Cofunction` may be passed here and its function space + will be used to build this :class:`Cofunction`. In this + case, the function values are copied. + :param val: NumPy array-like (or :class:`pyop2.types.dat.Dat`) providing initial values (optional). + If val is an existing :class:`Cofunction`, then the data will be shared. + :param name: user-defined name for this :class:`Cofunction` (optional). + :param dtype: optional data type for this :class:`Cofunction` + (defaults to ``ScalarType``). + """ + + V = function_space + if isinstance(V, Cofunction): + V = V.function_space() + # Deep copy prevents modifications to Vector copies. + # Also, this discard the value of `val` if it was specified (consistent with Function) + val = function_space.copy(deepcopy=True).dat + elif not isinstance(V, functionspaceimpl.FiredrakeDualSpace): + raise NotImplementedError("Can't make a Cofunction defined on a " + + str(type(function_space))) + + ufl.Cofunction.__init__(self, V.ufl_function_space()) + + # User comm + self.comm = V.comm + # Internal comm + self._comm = mpi.internal_comm(V.comm) + self._function_space = V + self.uid = utils._new_uid() + self._name = name or 'cofunction_%d' % self.uid + self._label = "a cofunction" + + if isinstance(val, vector.Vector): + # Allow constructing using a vector. + val = val.dat + if isinstance(val, (op2.Dat, op2.DatView, op2.MixedDat, op2.Global)): + assert val.comm == self._comm + self.dat = val + else: + self.dat = function_space.make_dat(val, dtype, self.name()) + + if isinstance(function_space, Cofunction): + self.dat.copy(function_space.dat) + + def __del__(self): + if hasattr(self, "_comm"): + mpi.decref(self._comm) + + def copy(self, deepcopy=True): + r"""Return a copy of this :class:`firedrake.function.CoordinatelessFunction`. + + :kwarg deepcopy: If ``True``, the default, the new + :class:`firedrake.function.CoordinatelessFunction` will allocate new space + and copy values. If ``False``, then the new + :class:`firedrake.function.CoordinatelessFunction` will share the dof values. + """ + if deepcopy: + val = type(self.dat)(self.dat) + else: + val = self.dat + return type(self)(self.function_space(), + val=val, name=self.name(), + dtype=self.dat.dtype) + + def _analyze_form_arguments(self): + # Cofunctions have one argument in primal space as they map from V to R. + self._arguments = (ufl_expr.Argument(self.function_space().dual(), 0),) + self._coefficients = (self,) + + @utils.cached_property + @FunctionMixin._ad_annotate_subfunctions + def subfunctions(self): + r"""Extract any sub :class:`Cofunction`\s defined on the component spaces + of this this :class:`Cofunction`'s :class:`.FunctionSpace`.""" + return tuple(type(self)(fs, dat) for fs, dat in zip(self.function_space(), self.dat)) + + @FunctionMixin._ad_annotate_subfunctions + def split(self): + import warnings + warnings.warn("The .split() method is deprecated, please use the .subfunctions property instead", category=FutureWarning) + return self.subfunctions + + @utils.cached_property + def _components(self): + if self.function_space().value_size == 1: + return (self, ) + else: + return tuple(type(self)(self.function_space().sub(i), val=op2.DatView(self.dat, i)) + for i in range(self.function_space().value_size)) + + def sub(self, i): + r"""Extract the ith sub :class:`Cofunction` of this :class:`Cofunction`. + + :arg i: the index to extract + + See also :attr:`subfunctions`. + + If the :class:`Cofunction` is defined on a + :func:`~.VectorFunctionSpace` or :func:`~.TensorFunctionSpace` + this returns a proxy object indexing the ith component of the space, + suitable for use in boundary condition application.""" + if len(self.function_space()) == 1: + return self._components[i] + return self.subfunctions[i] + + def function_space(self): + r"""Return the :class:`.FunctionSpace`, or :class:`.MixedFunctionSpace` + on which this :class:`Cofunction` is defined. + """ + return self._function_space + + @FunctionMixin._ad_not_implemented + @utils.known_pyop2_safe + def assign(self, expr, subset=None): + r"""Set the :class:`Cofunction` value to the pointwise value of + expr. expr may only contain :class:`Cofunction`\s on the same + :class:`.FunctionSpace` as the :class:`Cofunction` being assigned to. + + Similar functionality is available for the augmented assignment + operators `+=`, `-=`, `*=` and `/=`. For example, if `f` and `g` are + both Cofunctions on the same :class:`.FunctionSpace` then:: + + f += 2 * g + + will add twice `g` to `f`. + + If present, subset must be an :class:`pyop2.types.set.Subset` of this + :class:`Cofunction`'s ``node_set``. The expression will then + only be assigned to the nodes on that subset. + """ + expr = ufl.as_ufl(expr) + if isinstance(expr, ufl.classes.Zero): + self.dat.zero(subset=subset) + return self + elif (isinstance(expr, Cofunction) + and expr.function_space() == self.function_space()): + expr.dat.copy(self.dat, subset=subset) + return self + elif isinstance(expr, BaseForm): + # Enable to write down c += B where c is a Cofunction + # and B an appropriate BaseForm object + assembled_expr = firedrake.assemble(expr) + return self.assign(assembled_expr) + + raise ValueError('Cannot assign %s' % expr) + + def riesz_representation(self, riesz_map='L2', **solver_options): + """Return the Riesz representation of this :class:`Cofunction` with respect to the given Riesz map. + + Example: For a L2 Riesz map, the Riesz representation is obtained by solving + the linear system ``Mx = self``, where M is the L2 mass matrix, i.e. M = + with u and v trial and test functions, respectively. + + Parameters + ---------- + riesz_map : str or collections.abc.Callable + The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable. + solver_options : dict + Solver options to pass to the linear solver: + - solver_parameters: optional solver parameters. + - nullspace: an optional :class:`.VectorSpaceBasis` (or :class:`.MixedVectorSpaceBasis`) + spanning the null space of the operator. + - transpose_nullspace: as for the nullspace, but used to make the right hand side consistent. + - near_nullspace: as for the nullspace, but used to add the near nullspace. + - options_prefix: an optional prefix used to distinguish PETSc options. + If not provided a unique prefix will be created. + Use this option if you want to pass options to the solver from the command line + in addition to through the ``solver_parameters`` dict. + + Returns + ------- + firedrake.function.Function + Riesz representation of this :class:`Cofunction` with respect to the given Riesz map. + """ + return self._ad_convert_riesz(self, options={"function_space": self.function_space().dual(), + "riesz_representation": riesz_map, + "solver_options": solver_options}) + + @FunctionMixin._ad_annotate_iadd + @utils.known_pyop2_safe + def __iadd__(self, expr): + + if np.isscalar(expr): + self.dat += expr + return self + if isinstance(expr, vector.Vector): + expr = expr.function + if isinstance(expr, Cofunction) and \ + expr.function_space() == self.function_space(): + self.dat += expr.dat + return self + # Let Python hit `BaseForm.__add__` which relies on ufl.FormSum. + return NotImplemented + + @FunctionMixin._ad_annotate_isub + @utils.known_pyop2_safe + def __isub__(self, expr): + + if np.isscalar(expr): + self.dat -= expr + return self + if isinstance(expr, vector.Vector): + expr = expr.function + if isinstance(expr, Cofunction) and \ + expr.function_space() == self.function_space(): + self.dat -= expr.dat + return self + + # Let Python hit `BaseForm.__sub__` which relies on ufl.FormSum. + return NotImplemented + + @FunctionMixin._ad_annotate_imul + def __imul__(self, expr): + + if np.isscalar(expr): + self.dat *= expr + return self + if isinstance(expr, vector.Vector): + expr = expr.function + if isinstance(expr, Cofunction) and \ + expr.function_space() == self.function_space(): + self.dat *= expr.dat + return self + return NotImplemented + + def vector(self): + r"""Return a :class:`.Vector` wrapping the data in this + :class:`Cofunction`""" + return vector.Vector(self) + + @property + def node_set(self): + r"""A :class:`pyop2.types.set.Set` containing the nodes of this + :class:`Cofunction`. One or (for rank-1 and 2 + :class:`.FunctionSpace`\s) more degrees of freedom are stored + at each node. + """ + return self.function_space().node_set + + def ufl_id(self): + return self.uid + + def name(self): + r"""Return the name of this :class:`Cofunction`""" + return self._name + + def label(self): + r"""Return the label (a description) of this :class:`Cofunction`""" + return self._label + + def rename(self, name=None, label=None): + r"""Set the name and or label of this :class:`Cofunction` + + :arg name: The new name of the `Cofunction` (if not `None`) + :arg label: The new label for the `Cofunction` (if not `None`) + """ + if name is not None: + self._name = name + if label is not None: + self._label = label + + def __str__(self): + if self._name is not None: + return self._name + else: + return super(Cofunction, self).__str__() + + def cell_node_map(self): + return self.function_space().cell_node_map() diff --git a/firedrake/function.py b/firedrake/function.py index e2b50b1471..231ada423e 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -1,6 +1,7 @@ import numpy as np import sys import ufl +from ufl.duals import is_dual from ufl.formatting.ufl2unicode import ufl2unicode from ufl.domain import extract_unique_domain import cachetools @@ -14,6 +15,7 @@ from firedrake.utils import ScalarType, IntType, as_ctypes from firedrake import functionspaceimpl +from firedrake.cofunction import Cofunction from firedrake import utils from firedrake import vector from firedrake.adjoint_utils import FunctionMixin @@ -233,6 +235,11 @@ class Function(ufl.Coefficient, FunctionMixin): the :class:`.FunctionSpace`. """ + def __new__(cls, *args, **kwargs): + if args[0] and is_dual(args[0]): + return Cofunction(*args, **kwargs) + return super().__new__(cls, *args, **kwargs) + @PETSc.Log.EventDecorator() @FunctionMixin._ad_annotate_init def __init__(self, function_space, val=None, name=None, dtype=ScalarType, @@ -458,6 +465,41 @@ def assign(self, expr, subset=None): Assigner(self, expr, subset).assign() return self + def riesz_representation(self, riesz_map='L2'): + """Return the Riesz representation of this :class:`Function` with respect to the given Riesz map. + + Example: For a L2 Riesz map, the Riesz representation is obtained by taking the action + of ``M`` on ``self``, where M is the L2 mass matrix, i.e. M = + with u and v trial and test functions, respectively. + + Parameters + ---------- + riesz_map : str or collections.abc.Callable + The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable. + + Returns + ------- + firedrake.cofunction.Cofunction + Riesz representation of this :class:`Function` with respect to the given Riesz map. + """ + from firedrake.ufl_expr import action + from firedrake.assemble import assemble + + V = self.function_space() + if riesz_map == "l2": + return Cofunction(V.dual(), val=self.dat) + + elif riesz_map in ("L2", "H1"): + a = self._define_riesz_map_form(riesz_map, V) + return assemble(action(a, self)) + + elif callable(riesz_map): + return riesz_map(self) + + else: + raise NotImplementedError( + "Unknown Riesz representation %s" % riesz_map) + @FunctionMixin._ad_annotate_iadd def __iadd__(self, expr): from firedrake.assign import IAddAssigner diff --git a/firedrake/functionspace.py b/firedrake/functionspace.py index 6cad5857b6..19b8321f46 100644 --- a/firedrake/functionspace.py +++ b/firedrake/functionspace.py @@ -139,6 +139,45 @@ def FunctionSpace(mesh, family, degree=None, name=None, vfamily=None, return new +@PETSc.Log.EventDecorator() +def DualSpace(mesh, family, degree=None, name=None, vfamily=None, + vdegree=None): + """Create a :class:`.FunctionSpace`. + + :arg mesh: The mesh to determine the cell from. + :arg family: The finite element family. + :arg degree: The degree of the finite element. + :arg name: An optional name for the function space. + :arg vfamily: The finite element in the vertical dimension + (extruded meshes only). + :arg vdegree: The degree of the element in the vertical dimension + (extruded meshes only). + + The ``family`` argument may be an existing + :class:`ufl.FiniteElementBase`, in which case all other arguments + are ignored and the appropriate :class:`.FunctionSpace` is returned. + """ + element = make_scalar_element(mesh, family, degree, vfamily, vdegree) + + # Support FunctionSpace(mesh, MixedElement) + if type(element) is ufl.MixedElement: + return MixedFunctionSpace(element, mesh=mesh, name=name) + + # Check that any Vector/Tensor/Mixed modifiers are outermost. + check_element(element) + + # Otherwise, build the FunctionSpace. + topology = mesh.topology + if element.family() == "Real": + new = impl.RealFunctionSpace(topology, element, name=name) + else: + new = impl.FunctionSpace(topology, element, name=name) + if mesh is not topology: + return impl.FiredrakeDualSpace.create(new, mesh) + else: + return new + + @PETSc.Log.EventDecorator() def VectorFunctionSpace(mesh, family, degree=None, dim=None, name=None, vfamily=None, vdegree=None): diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index a29a384685..fe28e47300 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -19,18 +19,18 @@ from firedrake.petsc import PETSc -class WithGeometry(ufl.FunctionSpace): +class WithGeometryBase(object): r"""Attach geometric information to a :class:`~.FunctionSpace`. Function spaces on meshes with different geometry but the same topology can share data, except for their UFL cell. This class facilitates that. - Users should not instantiate a :class:`WithGeometry` object + Users should not instantiate a :class:`WithGeometryBase` object explicitly except in a small number of cases. - When instantiating a :class:`WithGeometry`, users should call - :meth:`WithGeometry.create` rather than ``__init__``. + When instantiating a :class:`WithGeometryBase`, users should call + :meth:`WithGeometryBase.create` rather than ``__init__``. :arg mesh: The mesh with geometric information to use. :arg element: The UFL element. @@ -72,7 +72,7 @@ def create(cls, function_space, mesh): component = function_space.component if function_space.parent is not None: - parent = WithGeometry.create(function_space.parent, mesh) + parent = cls.create(function_space.parent, mesh) else: parent = None @@ -102,7 +102,7 @@ def topological(self, val): @utils.cached_property def subfunctions(self): r"""Split into a tuple of constituent spaces.""" - return tuple(WithGeometry.create(subspace, self.mesh()) + return tuple(type(self).create(subspace, self.mesh()) for subspace in self.topological.subfunctions) mesh = ufl.FunctionSpace.ufl_domain @@ -128,7 +128,7 @@ def split(self): @utils.cached_property def _components(self): if len(self) == 1: - return tuple(WithGeometry.create(self.topological.sub(i), self.mesh()) + return tuple(type(self).create(self.topological.sub(i), self.mesh()) for i in range(self.value_size)) else: return self.subfunctions @@ -265,10 +265,10 @@ def __len__(self): return len(self.topological) def __repr__(self): - return "WithGeometry(%r, %r)" % (self.topological, self.mesh()) + return "%s(%r, %r)" % (self.__class__.__name__, self.topological, self.mesh()) def __str__(self): - return "WithGeometry(%s, %s)" % (self.topological, self.mesh()) + return "%s(%s, %s)" % (self.__class__.__name__, self.topological, self.mesh()) def __iter__(self): return iter(self.subfunctions) @@ -288,11 +288,11 @@ def __getattr__(self, name): return val def __dir__(self): - current = super(WithGeometry, self).__dir__() + current = super(type(self), self).__dir__() return list(OrderedDict.fromkeys(dir(self.topological) + current)) def boundary_nodes(self, sub_domain): - r"""Return the boundary nodes for this :class:`~.WithGeometry`. + r"""Return the boundary nodes for this :class:`~.WithGeometryBase`. :arg sub_domain: the mesh marker selecting which subset of facets to consider. :returns: A numpy array of the unique function space nodes on @@ -308,6 +308,28 @@ def collapse(self): return type(self).create(self.topological.collapse(), self.mesh()) +class WithGeometry(WithGeometryBase, ufl.FunctionSpace): + + def __init__(self, mesh, element, component=None, cargo=None): + super(WithGeometry, self).__init__(mesh, element, + component=component, + cargo=cargo) + + def dual(self): + return FiredrakeDualSpace.create(self.topological, self.mesh()) + + +class FiredrakeDualSpace(WithGeometryBase, ufl.functionspace.DualSpace): + + def __init__(self, mesh, element, component=None, cargo=None): + super(FiredrakeDualSpace, self).__init__(mesh, element, + component=component, + cargo=cargo) + + def dual(self): + return WithGeometry.create(self.topological, self.mesh()) + + class FunctionSpace(object): r"""A representation of a function space. @@ -1040,7 +1062,7 @@ def local_to_global_map(self, bcs, lgmap=None): @dataclass class FunctionSpaceCargo: - """Helper class carrying data for a :class:`WithGeometry`. + """Helper class carrying data for a :class:`WithGeometryBase`. It is required because it permits Firedrake to have stripped forms that still know Firedrake-specific information (e.g. that they are a @@ -1048,4 +1070,4 @@ class FunctionSpaceCargo: """ topological: FunctionSpace - parent: Optional[WithGeometry] + parent: Optional[WithGeometryBase] diff --git a/firedrake/linear_solver.py b/firedrake/linear_solver.py index 7af6157fe7..cc6bf8d242 100644 --- a/firedrake/linear_solver.py +++ b/firedrake/linear_solver.py @@ -1,5 +1,6 @@ from firedrake.exceptions import ConvergenceError import firedrake.function as function +import firedrake.cofunction as cofunction import firedrake.vector as vector import firedrake.matrix as matrix import firedrake.solving_utils as solving_utils @@ -52,7 +53,7 @@ def __init__(self, A, *, P=None, solver_parameters=None, solver_parameters = flatten_parameters(solver_parameters or {}) solver_parameters = solving_utils.set_defaults(solver_parameters, - A.a.arguments(), + A.arguments(), ksp_defaults=self.DEFAULT_KSP_PARAMETERS) self.A = A self.comm = A.comm @@ -107,18 +108,18 @@ def __del__(self): @cached_property def test_space(self): - return self.A.a.arguments()[0].function_space() + return self.A.arguments()[0].function_space() @cached_property def trial_space(self): - return self.A.a.arguments()[1].function_space() + return self.A.arguments()[1].function_space() @cached_property def _rhs(self): from firedrake.assemble import OneFormAssembler u = function.Function(self.trial_space) - b = function.Function(self.test_space) + b = cofunction.Cofunction(self.test_space.dual()) expr = -action(self.A.a, u) return u, OneFormAssembler(expr, tensor=b).assemble, b @@ -130,19 +131,25 @@ def _lifted(self, b): update() # blift contains -A u_bc blift += b - for bc in self.A.bcs: - bc.apply(blift) + if isinstance(blift, cofunction.Cofunction): + blift_func = blift.riesz_representation(riesz_map="l2") + for bc in self.A.bcs: + bc.apply(blift_func) + blift.assign(blift_func.riesz_representation(riesz_map="l2")) + else: + for bc in self.A.bcs: + bc.apply(blift) # blift is now b - A u_bc, and satisfies the boundary conditions return blift @PETSc.Log.EventDecorator() def solve(self, x, b): - if not isinstance(x, (function.Function, vector.Vector)): - raise TypeError("Provided solution is a '%s', not a Function or Vector" % type(x).__name__) + if not isinstance(x, (function.Function, vector.Vector, cofunction.Cofunction)): + raise TypeError("Provided solution is a '%s', not a Function, Vector or Cofunction" % type(x).__name__) if isinstance(b, vector.Vector): b = b.function - if not isinstance(b, function.Function): - raise TypeError("Provided RHS is a '%s', not a Function" % type(b).__name__) + if not isinstance(b, (function.Function, cofunction.Cofunction)): + raise TypeError("Provided RHS is a '%s', not a Function or Cofunction" % type(b).__name__) if len(self.trial_space) > 1 and self.nullspace is not None: self.nullspace._apply(self.trial_space.dof_dset.field_ises) diff --git a/firedrake/matrix.py b/firedrake/matrix.py index ef8e640ab3..6c9ade0c1f 100644 --- a/firedrake/matrix.py +++ b/firedrake/matrix.py @@ -1,18 +1,20 @@ -import abc import itertools +import ufl from pyop2 import op2 from pyop2.mpi import internal_comm, decref from pyop2.utils import as_tuple from firedrake.petsc import PETSc +from types import SimpleNamespace -class MatrixBase(object, metaclass=abc.ABCMeta): +class MatrixBase(ufl.Matrix): """A representation of the linear operator associated with a bilinear form and bcs. Explicitly assembled matrices and matrix-free matrix classes will derive from this - :arg a: the bilinear form this :class:`MatrixBase` represents. + :arg a: the bilinear form this :class:`MatrixBase` represents + or a tuple of the arguments it represents :arg bcs: an iterable of boundary conditions to apply to this :class:`MatrixBase`. May be `None` if there are no boundary @@ -20,12 +22,22 @@ class MatrixBase(object, metaclass=abc.ABCMeta): :arg mat_type: matrix type of assembled matrix, or 'matfree' for matrix-free """ def __init__(self, a, bcs, mat_type): - self.a = a + if isinstance(a, tuple): + self.a = None + test, trial = a + arguments = a + else: + self.a = a + test, trial = a.arguments() + arguments = None # Iteration over bcs must be in a parallel consistent order # (so we can't use a set, since the iteration order may differ # on different processes) + + ufl.Matrix.__init__(self, test.function_space(), trial.function_space()) + # Define arguments after `Matrix.__init__` since BaseForm sets `self._arguments` to None + self._arguments = arguments self.bcs = bcs - test, trial = a.arguments() self.comm = test.function_space().comm self._comm = internal_comm(self.comm) self.block_shape = (len(test.function_space()), @@ -36,6 +48,12 @@ def __init__(self, a, bcs, mat_type): Matrix type used in the assembly of the PETSc matrix: 'aij', 'baij', 'dense' or 'nest', or 'matfree' for matrix-free.""" + def arguments(self): + if self.a: + return self.a.arguments() + else: + return self._arguments + def __del__(self): if hasattr(self, "_comm"): decref(self._comm) @@ -100,7 +118,7 @@ class Matrix(MatrixBase): def __init__(self, a, bcs, mat_type, *args, **kwargs): # sets self._a, self._bcs, and self._mat_type - super(Matrix, self).__init__(a, bcs, mat_type) + MatrixBase.__init__(self, a, bcs, mat_type) options_prefix = kwargs.pop("options_prefix") self.M = op2.Mat(*args, mat_type=mat_type, **kwargs) self.petscmat = self.M.handle @@ -159,3 +177,34 @@ def assemble(self): # Ensures that if the matrix changed, the preconditioner is # updated if necessary. self.petscmat.assemble() + + +class AssembledMatrix(MatrixBase): + """A representation of a matrix that doesn't require knowing the underlying form. + This class wraps the relevant information for Python PETSc matrix. + + :arg a: A tuple of the arguments the matrix represents + + :arg bcs: an iterable of boundary conditions to apply to this + :class:`Matrix`. May be `None` if there are no boundary + conditions to apply. + + :arg petscmat: the already constructed petsc matrix this object represents. + """ + def __init__(self, a, bcs, petscmat, *args, **kwargs): + super(AssembledMatrix, self).__init__(a, bcs, "assembled") + + self.petscmat = petscmat + + # this allows call to self.M.handle without a new class + self.M = SimpleNamespace(handle=self.mat()) + + def mat(self): + return self.petscmat + + def __add__(self, other): + if isinstance(other, AssembledMatrix): + return self.petscmat + other.petscmat + else: + raise TypeError("Unable to add %s to AssembledMatrix" + % (type(other))) diff --git a/firedrake/matrix_free/operators.py b/firedrake/matrix_free/operators.py index 90f4d08611..fcbdf552a9 100644 --- a/firedrake/matrix_free/operators.py +++ b/firedrake/matrix_free/operators.py @@ -87,7 +87,7 @@ class ImplicitMatrixContext(object): @PETSc.Log.EventDecorator() def __init__(self, a, row_bcs=[], col_bcs=[], fc_params=None, appctx=None): - from firedrake.assemble import OneFormAssembler + from firedrake.assemble import get_form_assembler self.a = a self.aT = adjoint(a) @@ -110,17 +110,20 @@ def __init__(self, a, row_bcs=[], col_bcs=[], test_space, trial_space = [ a.arguments()[i].function_space() for i in (0, 1) ] - from firedrake import function + from firedrake import function, cofunction + # Need a cofunction since y receives the assembled result of Ax + self._ystar = cofunction.Cofunction(test_space.dual()) self._y = function.Function(test_space) self._x = function.Function(trial_space) + self._xstar = cofunction.Cofunction(trial_space.dual()) # These are temporary storage for holding the BC # values during matvec application. _xbc is for # the action and ._ybc is for transpose. if len(self.bcs) > 0: - self._xbc = function.Function(trial_space) + self._xbc = cofunction.Cofunction(trial_space.dual()) if len(self.col_bcs) > 0: - self._ybc = function.Function(test_space) + self._ybc = cofunction.Cofunction(test_space.dual()) # Get size information from template vecs on test and trial spaces trial_vec = trial_space.dof_dset.layout_vec @@ -141,10 +144,10 @@ def __init__(self, a, row_bcs=[], col_bcs=[], elif isinstance(bc, EquationBCSplit): self.bcs_action.append(bc.reconstruct(action_x=self._x)) - self._assemble_action = OneFormAssembler(self.action, tensor=self._y, - bcs=self.bcs_action, - form_compiler_parameters=self.fc_params, - zero_bc_nodes=True).assemble + self._assemble_action = get_form_assembler(self.action, tensor=self._ystar, + bcs=self.bcs_action, + form_compiler_parameters=self.fc_params, + zero_bc_nodes=True) # For assembling action(adjoint(f), self._y) # Sorted list of equation bcs @@ -158,13 +161,13 @@ def __init__(self, a, row_bcs=[], col_bcs=[], for bc in self.bcs: for ebc in bc.sorted_equation_bcs(): self._assemble_actionT.append( - OneFormAssembler(action(adjoint(ebc.f), self._y), tensor=self._xbc, - form_compiler_parameters=self.fc_params).assemble) + get_form_assembler(action(adjoint(ebc.f), self._y), tensor=self._xbc, + form_compiler_parameters=self.fc_params)) # Domain last self._assemble_actionT.append( - OneFormAssembler(self.actionT, - tensor=self._x if len(self.bcs) == 0 else self._xbc, - form_compiler_parameters=self.fc_params).assemble) + get_form_assembler(self.actionT, + tensor=self._xstar if len(self.bcs) == 0 else self._xbc, + form_compiler_parameters=self.fc_params)) def __del__(self): if hasattr(self, "_comm"): @@ -172,22 +175,24 @@ def __del__(self): @cached_property def _diagonal(self): - from firedrake import Function + from firedrake import Cofunction assert self.on_diag - return Function(self._x.function_space()) + return Cofunction(self._x.function_space().dual()) @cached_property def _assemble_diagonal(self): - from firedrake.assemble import OneFormAssembler - return OneFormAssembler(self.a, tensor=self._diagonal, - form_compiler_parameters=self.fc_params, - diagonal=True).assemble + from firedrake.assemble import get_form_assembler + return get_form_assembler(self.a, tensor=self._diagonal, + form_compiler_parameters=self.fc_params, + diagonal=True) def getDiagonal(self, mat, vec): self._assemble_diagonal() + diagonal_func = self._diagonal.riesz_representation(riesz_map="l2") for bc in self.bcs: # Operator is identity on boundary nodes - bc.set(self._diagonal, 1) + bc.set(diagonal_func, 1) + self._diagonal.assign(diagonal_func.riesz_representation(riesz_map="l2")) with self._diagonal.dat.vec_ro as v: v.copy(vec) @@ -220,12 +225,12 @@ def mult(self, mat, X, Y): with self._xbc.dat.vec_wo as v: X.copy(v) for bc in self.row_bcs: - bc.set(self._y, self._xbc) + bc.set(self._ystar, self._xbc) else: for bc in self.row_bcs: - bc.zero(self._y) + bc.zero(self._ystar) - with self._y.dat.vec_ro as v: + with self._ystar.dat.vec_ro as v: v.copy(Y) @PETSc.Log.EventDecorator() @@ -300,14 +305,14 @@ def multTranspose(self, mat, Y, X): if len(self.bcs) > 0: # Accumulate values in self._x - self._x.dat.zero() + self._xstar.dat.zero() # Apply actionTs in sorted order for aT, obj in zip(self._assemble_actionT, self.objs_actionT): # zero columns associated with DirichletBCs/EquationBCs for obc in obj.bcs: obc.zero(self._y) aT() - self._x += self._xbc + self._xstar += self._xbc else: # No DirichletBC/EquationBC # There is only a single element in the list (for the domain equation). @@ -321,12 +326,12 @@ def multTranspose(self, mat, Y, X): with self._ybc.dat.vec_wo as v: Y.copy(v) for bc in self.col_bcs: - bc.set(self._x, self._ybc) + bc.set(self._xstar, self._ybc) else: for bc in self.col_bcs: - bc.zero(self._x) + bc.zero(self._xstar) - with self._x.dat.vec_ro as v: + with self._xstar.dat.vec_ro as v: v.copy(X) def view(self, mat, viewer=None): diff --git a/firedrake/mg/interface.py b/firedrake/mg/interface.py index 0029663b30..72bf260c21 100644 --- a/firedrake/mg/interface.py +++ b/firedrake/mg/interface.py @@ -1,10 +1,10 @@ from pyop2 import op2 import firedrake +from firedrake import ufl_expr from firedrake.petsc import PETSc from . import utils from . import kernels -from ufl.domain import extract_unique_domain __all__ = ["prolong", "restrict", "inject"] @@ -45,8 +45,8 @@ def prolong(coarse, fine): src.copy(dest) return fine - hierarchy, coarse_level = utils.get_level(extract_unique_domain(coarse)) - _, fine_level = utils.get_level(extract_unique_domain(fine)) + hierarchy, coarse_level = utils.get_level(ufl_expr.extract_unique_domain(coarse)) + _, fine_level = utils.get_level(ufl_expr.extract_unique_domain(fine)) refinements_per_level = hierarchy.refinements_per_level repeat = (fine_level - coarse_level)*refinements_per_level next_level = coarse_level * refinements_per_level @@ -106,8 +106,8 @@ def restrict(fine_dual, coarse_dual): src.copy(dest) return coarse_dual - hierarchy, coarse_level = utils.get_level(extract_unique_domain(coarse_dual)) - _, fine_level = utils.get_level(extract_unique_domain(fine_dual)) + hierarchy, coarse_level = utils.get_level(ufl_expr.extract_unique_domain(coarse_dual)) + _, fine_level = utils.get_level(ufl_expr.extract_unique_domain(fine_dual)) refinements_per_level = hierarchy.refinements_per_level repeat = (fine_level - coarse_level)*refinements_per_level next_level = fine_level * refinements_per_level @@ -180,10 +180,10 @@ def inject(fine, coarse): # solve inner(u_c, v_c)*dx_c == inner(f, v_c)*dx_c kernel, dg = kernels.inject_kernel(Vf, Vc) - hierarchy, coarse_level = utils.get_level(extract_unique_domain(coarse)) + hierarchy, coarse_level = utils.get_level(ufl_expr.extract_unique_domain(coarse)) if dg and not hierarchy.nested: raise NotImplementedError("Sorry, we can't do supermesh projections yet!") - _, fine_level = utils.get_level(extract_unique_domain(fine)) + _, fine_level = utils.get_level(ufl_expr.extract_unique_domain(fine)) refinements_per_level = hierarchy.refinements_per_level repeat = (fine_level - coarse_level)*refinements_per_level next_level = fine_level * refinements_per_level diff --git a/firedrake/ml/pytorch.py b/firedrake/ml/pytorch.py index 5717a79b82..e3c7832742 100644 --- a/firedrake/ml/pytorch.py +++ b/firedrake/ml/pytorch.py @@ -19,6 +19,7 @@ from functools import partial from firedrake.function import Function +from firedrake.cofunction import Cofunction from firedrake.vector import Vector from firedrake.constant import Constant from firedrake_citations import Citations @@ -86,9 +87,11 @@ def backward(ctx, grad_output): F = ctx.metadata['F'] V = ctx.metadata['V_output'] # Convert PyTorch gradient to Firedrake - adj_input = from_torch(grad_output, V) - if isinstance(adj_input, Function): - adj_input = adj_input.vector() + V_adj = V.dual() if V else V + adj_input = from_torch(grad_output, V_adj) + if isinstance(adj_input, Constant) and adj_input.ufl_shape == (): + # This will later on result in an `AdjFloat` adjoint input instead of a Constant + adj_input = float(adj_input) # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional adj_output = F.derivative(adj_input=adj_input) @@ -140,7 +143,7 @@ def _extract_function_space(x): firedrake.functionspaceimpl.WithGeometry or None Extracted function space. """ - if isinstance(x, Function): + if isinstance(x, (Function, Cofunction)): return x.function_space() elif isinstance(x, Vector): return _extract_function_space(x.function) @@ -172,7 +175,7 @@ def to_torch(x, gather=False, batched=True, **kwargs): torch.Tensor PyTorch tensor representing the Firedrake object `x`. """ - if isinstance(x, (Function, Vector)): + if isinstance(x, (Function, Cofunction, Vector)): if gather: # Gather data from all processes x_P = torch.tensor(x.vector().gather(), **kwargs) diff --git a/firedrake/parloops.py b/firedrake/parloops.py index f36622c2cf..08cb57e3d8 100644 --- a/firedrake/parloops.py +++ b/firedrake/parloops.py @@ -4,7 +4,7 @@ import collections from ufl.indexed import Indexed -from ufl.domain import join_domains, extract_domains +from ufl.domain import join_domains from pyop2 import op2, READ, WRITE, RW, INC, MIN, MAX import loopy @@ -12,6 +12,7 @@ from firedrake.parameters import target from firedrake import constant +from firedrake.ufl_expr import extract_domains from firedrake.petsc import PETSc from cachetools import LRUCache diff --git a/firedrake/preconditioners/fdm.py b/firedrake/preconditioners/fdm.py index 6aba95b301..68317679aa 100644 --- a/firedrake/preconditioners/fdm.py +++ b/firedrake/preconditioners/fdm.py @@ -9,6 +9,7 @@ from firedrake.preconditioners.facet_split import split_dofs, restricted_dofs from firedrake.formmanipulation import ExtractSubBlock from firedrake.function import Function +from firedrake.cofunction import Cofunction from firedrake.functionspace import FunctionSpace from firedrake.ufl_expr import TestFunction, TestFunctions, TrialFunctions from firedrake.utils import cached_property @@ -437,7 +438,7 @@ def assemble_coefficients(self, J, fcp, block_diagonal=True): assembly_callables.append(ctx._assemble_block_diagonal) else: from firedrake.assemble import OneFormAssembler - tensor = Function(Z) + tensor = Function(Z.dual()) coefficients["beta"] = tensor.subfunctions[0] coefficients["alpha"] = tensor.subfunctions[1] assembly_callables.append(OneFormAssembler(mixed_form, tensor=tensor, diagonal=True, @@ -1627,7 +1628,7 @@ def assemble_coefficients(self, J, fcp): # assemble second order coefficient if not isinstance(alpha, ufl.constantvalue.Zero): Q = FunctionSpace(mesh, ufl.TensorElement(DG, shape=alpha.ufl_shape)) - tensor = coefficients.setdefault("alpha", Function(Q)) + tensor = coefficients.setdefault("alpha", Function(Q.dual())) assembly_callables.append(OneFormAssembler(ufl.inner(TestFunction(Q), alpha)*dx, tensor=tensor, form_compiler_parameters=fcp).assemble) @@ -1649,7 +1650,7 @@ def assemble_coefficients(self, J, fcp): # keep diagonal beta = ufl.diag_vector(beta) Q = FunctionSpace(mesh, ufl.TensorElement(DG, shape=beta.ufl_shape) if beta.ufl_shape else DG) - tensor = coefficients.setdefault("beta", Function(Q)) + tensor = coefficients.setdefault("beta", Function(Q.dual())) assembly_callables.append(OneFormAssembler(ufl.inner(TestFunction(Q), beta)*dx, tensor=tensor, form_compiler_parameters=fcp).assemble) @@ -1672,12 +1673,12 @@ def assemble_coefficients(self, J, fcp): G = G * abs(ufl.JacobianDeterminant(mesh)) Q = FunctionSpace(mesh, ufl.TensorElement(DGT, shape=G.ufl_shape)) - tensor = coefficients.setdefault("Gq_facet", Function(Q)) + tensor = coefficients.setdefault("Gq_facet", Function(Q.dual())) assembly_callables.append(OneFormAssembler(ifacet_inner(TestFunction(Q), G), tensor=tensor, form_compiler_parameters=fcp).assemble) PT = Piola.T Q = FunctionSpace(mesh, ufl.TensorElement(DGT, shape=PT.ufl_shape)) - tensor = coefficients.setdefault("PT_facet", Function(Q)) + tensor = coefficients.setdefault("PT_facet", Function(Q.dual())) assembly_callables.append(OneFormAssembler(ifacet_inner(TestFunction(Q), PT), tensor=tensor, form_compiler_parameters=fcp).assemble) @@ -1701,7 +1702,7 @@ def assemble_coefficients(self, J, fcp): ds_ext = ufl.Measure(itype, domain=mesh, subdomain_id=it.subdomain_id(), metadata=md) forms.append(ufl.inner(test, beta)*ds_ext) - tensor = coefficients.setdefault("bcflags", Function(Q)) + tensor = coefficients.setdefault("bcflags", Function(Q.dual())) if len(forms): form = sum(forms) if len(form.arguments()) == 1: @@ -1819,7 +1820,7 @@ def extrude_interior_facet_maps(V): local_facet_data_fun: maps interior facets to the local facet numbering in the two cells sharing it, nfacets: the total number of interior facets owned by this process """ - if isinstance(V, Function): + if isinstance(V, (Function, Cofunction)): V = V.function_space() mesh = V.mesh() intfacets = mesh.interior_facets diff --git a/firedrake/projection.py b/firedrake/projection.py index 1c6ae84644..4b47273f73 100644 --- a/firedrake/projection.py +++ b/firedrake/projection.py @@ -171,7 +171,7 @@ def solve(x, b): @cached_property def residual(self): - return firedrake.Function(self.target.function_space()) + return firedrake.Cofunction(self.target.function_space().dual()) @abc.abstractproperty def rhs(self): diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py index fa1c50ea37..8dcad2ee7d 100644 --- a/firedrake/slate/slate.py +++ b/firedrake/slate/slate.py @@ -18,9 +18,10 @@ from collections import OrderedDict, namedtuple, defaultdict -from ufl import Coefficient, Constant +from ufl import Constant +from ufl.coefficient import BaseCoefficient -from firedrake.function import Function +from firedrake.function import Function, Cofunction from firedrake.utils import cached_property from itertools import chain, count @@ -236,7 +237,7 @@ def coeff_map(self): coeff_map[m].update(c.indices[0]) else: m = self.coefficients().index(c) - split_map = tuple(range(len(c.subfunctions))) if isinstance(c, Function) or isinstance(c, Constant) else tuple(range(1)) + split_map = tuple(range(len(c.subfunctions))) if isinstance(c, Function) or isinstance(c, Constant) or isinstance(c, Cofunction) else tuple(range(1)) coeff_map[m].update(split_map) return tuple((k, tuple(sorted(v)))for k, v in coeff_map.items()) @@ -434,12 +435,12 @@ def integrals(self): def __new__(cls, function): if isinstance(function, AssembledVector): return function - elif isinstance(function, Coefficient): + elif isinstance(function, BaseCoefficient): self = super().__new__(cls) self._function = function return self else: - raise TypeError("Expecting a Coefficient or AssembledVector (not a %r)" % + raise TypeError("Expecting a BaseCoefficient or AssembledVector (not a %r)" % type(function)) @cached_property @@ -514,7 +515,7 @@ def __new__(cls, function, expr, indices): block = Block(expr, indices) split_functions = block.form if isinstance(split_functions, tuple) \ - and all(isinstance(f, Coefficient) for f in split_functions): + and all(isinstance(f, BaseCoefficient) for f in split_functions): self = TensorBase.__new__(cls) self._function = split_functions self._indices = indices @@ -522,7 +523,7 @@ def __new__(cls, function, expr, indices): self._block = block return self else: - raise TypeError("Expecting a tuple of Coefficients (not a %r)" % + raise TypeError("Expecting a tuple of BaseCoefficients (not a %r)" % type(split_functions)) @cached_property diff --git a/firedrake/slate/static_condensation/hybridization.py b/firedrake/slate/static_condensation/hybridization.py index 37d0df784c..0f2cad616d 100644 --- a/firedrake/slate/static_condensation/hybridization.py +++ b/firedrake/slate/static_condensation/hybridization.py @@ -37,7 +37,7 @@ def initialize(self, pc): A KSP is created for the Lagrange multiplier system. """ - from firedrake import (FunctionSpace, Function, Constant, + from firedrake import (FunctionSpace, Cofunction, Function, Constant, TrialFunction, TrialFunctions, TestFunction, DirichletBC) from firedrake.assemble import allocate_matrix, OneFormAssembler, TwoFormAssembler @@ -92,10 +92,10 @@ def initialize(self, pc): # Set up the functions for the original, hybridized # and schur complement systems - self.broken_solution = Function(V_d) + self.broken_solution = Cofunction(V_d.dual()) self.broken_residual = Function(V_d) self.trace_solution = Function(TraceSpace) - self.unbroken_solution = Function(V) + self.unbroken_solution = Cofunction(V.dual()) self.unbroken_residual = Function(V) shapes = (V[self.vidx].finat_element.space_dimension(), @@ -203,7 +203,7 @@ def initialize(self, pc): schur_rhs, schur_comp = self.schur_builder.build_schur(AssembledVector(self.broken_residual)) # Assemble the Schur complement operator and right-hand side - self.schur_rhs = Function(TraceSpace) + self.schur_rhs = Cofunction(TraceSpace.dual()) self._assemble_Srhs = OneFormAssembler(schur_rhs, tensor=self.schur_rhs, form_compiler_parameters=self.ctx.fc_params).assemble diff --git a/firedrake/slate/static_condensation/scpc.py b/firedrake/slate/static_condensation/scpc.py index 3ebf22186b..011a802627 100644 --- a/firedrake/slate/static_condensation/scpc.py +++ b/firedrake/slate/static_condensation/scpc.py @@ -30,6 +30,7 @@ def initialize(self, pc): from firedrake.assemble import allocate_matrix, OneFormAssembler, TwoFormAssembler from firedrake.bcs import DirichletBC from firedrake.function import Function + from firedrake.cofunction import Cofunction from firedrake.functionspace import FunctionSpace from firedrake.parloops import par_loop, INC from ufl import dx @@ -76,9 +77,9 @@ def initialize(self, pc): mat_type = PETSc.Options().getString(prefix + "mat_type", "aij") self.c_field = c_field - self.condensed_rhs = Function(Vc) + self.condensed_rhs = Cofunction(Vc.dual()) self.residual = Function(W) - self.solution = Function(W) + self.solution = Cofunction(W.dual()) shapes = (Vc.finat_element.space_dimension(), np.prod(Vc.shape)) diff --git a/firedrake/solving.py b/firedrake/solving.py index 0fbb9f44d0..3c8385621c 100644 --- a/firedrake/solving.py +++ b/firedrake/solving.py @@ -150,7 +150,7 @@ def _solve_varproblem(*args, **kwargs): appctx = kwargs.get("appctx", {}) # Solve linear variational problem - if isinstance(eq.lhs, ufl.Form) and isinstance(eq.rhs, ufl.Form): + if isinstance(eq.lhs, ufl.Form) and isinstance(eq.rhs, (ufl.Form, ufl.Cofunction)): # Create problem problem = vs.LinearVariationalProblem(eq.lhs, eq.rhs, u, bcs, Jp, form_compiler_parameters=form_compiler_parameters) diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index 5055582f20..f307b9a382 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -4,7 +4,7 @@ from pyop2 import op2 from firedrake_configuration import get_config -from firedrake import function, dmhooks +from firedrake import function, cofunction, dmhooks from firedrake.exceptions import ConvergenceError from firedrake.petsc import PETSc from firedrake.formmanipulation import ExtractSubBlock @@ -178,7 +178,7 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None, post_jacobian_callback=None, post_function_callback=None, options_prefix=None, transfer_manager=None): - from firedrake.assemble import OneFormAssembler + from firedrake.assemble import get_form_assembler if pmat_type is None: pmat_type = mat_type @@ -233,9 +233,9 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None, self.bcs_J = tuple(bc.extract_form('J') for bc in problem.bcs) self.bcs_Jp = tuple(bc.extract_form('Jp') for bc in problem.bcs) - self._assemble_residual = OneFormAssembler(self.F, self._F, self.bcs_F, - form_compiler_parameters=self.fcp, - zero_bc_nodes=True).assemble + self._assemble_residual = get_form_assembler(self.F, self._F, bcs=self.bcs_F, + form_compiler_parameters=self.fcp, + zero_bc_nodes=True) self._jacobian_assembled = False self._splits = {} @@ -510,9 +510,10 @@ def _jac(self): @cached_property def _assemble_jac(self): - from firedrake.assemble import TwoFormAssembler - return TwoFormAssembler(self.J, self._jac, bcs=self.bcs_J, - form_compiler_parameters=self.fcp).assemble + from firedrake.assemble import get_form_assembler + return get_form_assembler(self.J, self._jac, bcs=self.bcs_J, + mat_type=self.mat_type, + form_compiler_parameters=self.fcp) @cached_property def is_mixed(self): @@ -533,10 +534,10 @@ def _pjac(self): @cached_property def _assemble_pjac(self): - from firedrake.assemble import TwoFormAssembler - return TwoFormAssembler(self.Jp, self._pjac, bcs=self.bcs_Jp, - form_compiler_parameters=self.fcp).assemble + from firedrake.assemble import get_form_assembler + return get_form_assembler(self.Jp, self._pjac, bcs=self.bcs_Jp, + form_compiler_parameters=self.fcp) @cached_property def _F(self): - return function.Function(self.F.arguments()[0].function_space()) + return cofunction.Cofunction(self.F.arguments()[0].function_space().dual()) diff --git a/firedrake/ufl_expr.py b/firedrake/ufl_expr.py index 3bca83fb7a..98fe617949 100644 --- a/firedrake/ufl_expr.py +++ b/firedrake/ufl_expr.py @@ -1,15 +1,16 @@ import ufl import ufl.argument +from ufl.duals import is_dual from ufl.split_functions import split from ufl.algorithms import extract_arguments, extract_coefficients import firedrake -from firedrake import utils +from firedrake import utils, function, cofunction from firedrake.constant import Constant from firedrake.petsc import PETSc -__all__ = ['Argument', 'TestFunction', 'TrialFunction', +__all__ = ['Argument', 'Coargument', 'TestFunction', 'TrialFunction', 'TestFunctions', 'TrialFunctions', 'derivative', 'adjoint', 'action', 'CellSize', 'FacetNormal'] @@ -29,6 +30,12 @@ class Argument(ufl.argument.Argument): :func:`TestFunction`, with a number of ``1`` it is used as a :func:`TrialFunction`. """ + + def __new__(cls, *args, **kwargs): + if args[0] and is_dual(args[0]): + return Coargument(*args, **kwargs) + return super().__new__(cls, *args, **kwargs) + def __init__(self, function_space, number, part=None): super(Argument, self).__init__(function_space.ufl_function_space(), number, part=part) @@ -70,6 +77,76 @@ def reconstruct(self, function_space=None, return Argument(function_space, number, part=part) +class Coargument(ufl.argument.Coargument): + """Representation of an argument to a form in a dual space. + + :arg function_space: the :class:`.FunctionSpace` the argument + corresponds to. + :arg number: the number of the argument being constructed. + :kwarg part: optional index (mostly ignored). + """ + + def __init__(self, function_space, number, part=None): + super(Coargument, self).__init__(function_space.ufl_function_space(), + number, part=part) + self._function_space = function_space + + @utils.cached_property + def cell_node_map(self): + return self.function_space().cell_node_map + + @utils.cached_property + def interior_facet_node_map(self): + return self.function_space().interior_facet_node_map + + @utils.cached_property + def exterior_facet_node_map(self): + return self.function_space().exterior_facet_node_map + + def function_space(self): + return self._function_space + + def make_dat(self): + return self.function_space().make_dat() + + def _analyze_form_arguments(self, outer_form=None): + # Returns the argument found in the Coargument object + self._coefficients = () + # Coarguments map from V* to V*, i.e. V* -> V*, or equivalently V* x V -> R. + # So they have one argument in the primal space and one in the dual space. + # However, when they are composed with linear forms with dual arguments, such as BaseFormOperators, + # the primal argument is discarded when analysing the argument as Coarguments. + if not outer_form: + self._arguments = (Argument(self.function_space().dual(), 0), self) + else: + self._arguments = (self,) + + def reconstruct(self, function_space=None, + number=None, part=None): + if function_space is None or function_space == self.function_space(): + function_space = self.function_space() + if number is None or number == self._number: + number = self._number + if part is None or part == self._part: + part = self._part + if number is self._number and part is self._part \ + and function_space is self.function_space(): + return self + if not isinstance(number, int): + raise TypeError(f"Expecting an int, not {number}") + if function_space.ufl_element().value_shape() != self.ufl_element().value_shape(): + raise ValueError("Cannot reconstruct an Coargument with a different value shape.") + return Coargument(function_space, number, part=part) + + def equals(self, other): + if type(other) is not Coargument: + return False + if self is other: + return True + return (self._function_space == other._function_space + and self._number == other._number and self._part == other._part) + + @PETSc.Log.EventDecorator() def TestFunction(function_space, part=None): """Build a test function on the specified function space. @@ -170,7 +247,7 @@ def derivative(form, u, du=None, coefficient_derivatives=None): coords = mesh.coordinates u = ufl.SpatialCoordinate(mesh) V = coords.function_space() - elif isinstance(uc, firedrake.Function): + elif isinstance(uc, (firedrake.Function, firedrake.Cofunction)): V = uc.function_space() elif isinstance(uc, firedrake.Constant): if uc.ufl_shape != (): @@ -277,3 +354,41 @@ def FacetNormal(mesh): """ mesh.init() return ufl.FacetNormal(mesh) + + +def extract_domains(func): + """Extract the domain from `func`. + + Parameters + ---------- + x : firedrake.function.Function, firedrake.cofunction.Cofunction, or firedrake.constant.Constant + The function to extract the domain from. + + Returns + ------- + list of firedrake.mesh.MeshGeometry + Extracted domains. + """ + if isinstance(func, (function.Function, cofunction.Cofunction)): + return [func.function_space().mesh()] + else: + return ufl.domain.extract_domains(func) + + +def extract_unique_domain(func): + """Extract the single unique domain `func` is defined on. + + Parameters + ---------- + x : firedrake.function.Function, firedrake.cofunction.Cofunction, or firedrake.constant.Constant + The function to extract the domain from. + + Returns + ------- + list of firedrake.mesh.MeshGeometry + Extracted domains. + """ + if isinstance(func, (function.Function, cofunction.Cofunction)): + return func.function_space().mesh() + else: + return ufl.domain.extract_unique_domain(func) diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 476c70950c..55bfa00b57 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -18,16 +18,16 @@ def check_pde_args(F, J, Jp): - if not isinstance(F, (ufl.Form, slate.slate.TensorBase)): - raise TypeError("Provided residual is a '%s', not a Form or Slate Tensor" % type(F).__name__) + if not isinstance(F, (ufl.BaseForm, slate.slate.TensorBase)): + raise TypeError("Provided residual is a '%s', not a BaseForm or Slate Tensor" % type(F).__name__) if len(F.arguments()) != 1: raise ValueError("Provided residual is not a linear form") - if not isinstance(J, (ufl.Form, slate.slate.TensorBase)): - raise TypeError("Provided Jacobian is a '%s', not a Form or Slate Tensor" % type(J).__name__) + if not isinstance(J, (ufl.BaseForm, slate.slate.TensorBase)): + raise TypeError("Provided Jacobian is a '%s', not a BaseForm or Slate Tensor" % type(J).__name__) if len(J.arguments()) != 2: raise ValueError("Provided Jacobian is not a bilinear form") - if Jp is not None and not isinstance(Jp, (ufl.Form, slate.slate.TensorBase)): - raise TypeError("Provided preconditioner is a '%s', not a Form or Slate Tensor" % type(Jp).__name__) + if Jp is not None and not isinstance(Jp, (ufl.BaseForm, slate.slate.TensorBase)): + raise TypeError("Provided preconditioner is a '%s', not a BaseForm or Slate Tensor" % type(Jp).__name__) if Jp is not None and len(Jp.arguments()) != 2: raise ValueError("Provided preconditioner is not a bilinear form") @@ -311,10 +311,10 @@ def __init__(self, a, L, u, bcs=None, aP=None, # In the linear case, the Jacobian is the equation LHS. J = a # Jacobian is checked in superclass, but let's check L here. - if not isinstance(L, (ufl.Form, slate.slate.TensorBase)) and L == 0: + if not isinstance(L, (ufl.Form, ufl.Cofunction, slate.slate.TensorBase)) and L == 0: F = ufl_expr.action(J, u) else: - if not isinstance(L, (ufl.Form, slate.slate.TensorBase)): + if not isinstance(L, (ufl.Form, ufl.Cofunction, slate.slate.TensorBase)): raise TypeError("Provided RHS is a '%s', not a Form or Slate Tensor" % type(L).__name__) if len(L.arguments()) != 1: raise ValueError("Provided RHS is not a linear form") diff --git a/firedrake/vector.py b/firedrake/vector.py index bd4febe684..23915aa5fd 100644 --- a/firedrake/vector.py +++ b/firedrake/vector.py @@ -2,6 +2,7 @@ import numpy as np +from ufl.form import ZeroBaseForm from pyop2.mpi import internal_comm, decref import firedrake @@ -53,7 +54,7 @@ def __init__(self, x): """ if isinstance(x, Vector): self.function = type(x.function)(x.function) - elif isinstance(x, firedrake.Function): + elif isinstance(x, (firedrake.Function, firedrake.Cofunction)): self.function = x else: raise RuntimeError("Don't know how to build a Vector from a %r" % type(x)) @@ -111,6 +112,8 @@ def __rmul__(self, other): def __add__(self, other): """Add other to self""" sum = self.copy() + if isinstance(other, ZeroBaseForm): + return sum try: sum.dat += other.dat except AttributeError: @@ -122,6 +125,8 @@ def __radd__(self, other): def __iadd__(self, other): """Add other to self""" + if isinstance(other, ZeroBaseForm): + return self try: self.dat += other.dat except AttributeError: diff --git a/tests/multigrid/test_p_multigrid.py b/tests/multigrid/test_p_multigrid.py index 3954536f1f..456dec13b9 100644 --- a/tests/multigrid/test_p_multigrid.py +++ b/tests/multigrid/test_p_multigrid.py @@ -324,7 +324,11 @@ def test_p_multigrid_mixed(mat_type): "pmg_mg_levels": relax, "pmg_mg_coarse": coarse} - basis = VectorSpaceBasis([assemble(TestFunction(Z.sub(1))*dx)]) + # Make the Function spanning the nullspace + c_basis = assemble(TestFunction(Z.sub(1))*dx) + f_basis = Function(c_basis.function_space().dual(), val=c_basis.dat) + + basis = VectorSpaceBasis([f_basis]) basis.orthonormalize() nullspace = MixedVectorSpaceBasis(Z, [Z.sub(0), basis]) problem = NonlinearVariationalProblem(F, z, bcs) diff --git a/tests/regression/test_adjoint_operators.py b/tests/regression/test_adjoint_operators.py index 9e9f7dfbcc..407936f187 100644 --- a/tests/regression/test_adjoint_operators.py +++ b/tests/regression/test_adjoint_operators.py @@ -314,9 +314,9 @@ def test_interpolate_hessian_linear_expr(): g = f.copy(deepcopy=True) dJdm = J.block_variable.tlm_value - assert isinstance(f.block_variable.adj_value, Vector) - assert isinstance(f.block_variable.hessian_value, Vector) - Hm = f.block_variable.hessian_value.inner(h.vector()) + assert isinstance(f.block_variable.adj_value, Cofunction) + assert isinstance(f.block_variable.hessian_value, Cofunction) + Hm = f.block_variable.hessian_value.dat.inner(h.dat) # If the new interpolate block has the right hessian, taylor test # convergence rate should be as for the unmodified test. assert taylor_test(Jhat, g, h, dJdm=dJdm, Hm=Hm) > 2.9 @@ -372,9 +372,9 @@ def test_interpolate_hessian_nonlinear_expr(): g = f.copy(deepcopy=True) dJdm = J.block_variable.tlm_value - assert isinstance(f.block_variable.adj_value, Vector) - assert isinstance(f.block_variable.hessian_value, Vector) - Hm = f.block_variable.hessian_value.inner(h.vector()) + assert isinstance(f.block_variable.adj_value, Cofunction) + assert isinstance(f.block_variable.hessian_value, Cofunction) + Hm = f.block_variable.hessian_value.dat.inner(h.dat) # If the new interpolate block has the right hessian, taylor test # convergence rate should be as for the unmodified test. assert taylor_test(Jhat, g, h, dJdm=dJdm, Hm=Hm) > 2.9 @@ -434,9 +434,9 @@ def test_interpolate_hessian_nonlinear_expr_multi(): g = f.copy(deepcopy=True) dJdm = J.block_variable.tlm_value - assert isinstance(f.block_variable.adj_value, Vector) - assert isinstance(f.block_variable.hessian_value, Vector) - Hm = f.block_variable.hessian_value.inner(h.vector()) + assert isinstance(f.block_variable.adj_value, Cofunction) + assert isinstance(f.block_variable.hessian_value, Cofunction) + Hm = f.block_variable.hessian_value.dat.inner(h.dat) # If the new interpolate block has the right hessian, taylor test # convergence rate should be as for the unmodified test. assert taylor_test(Jhat, g, h, dJdm=dJdm, Hm=Hm) > 2.9 @@ -638,9 +638,9 @@ def test_supermesh_project_hessian(vector): tape.evaluate_hessian() dJdm = J.block_variable.tlm_value - assert isinstance(source.block_variable.adj_value, Vector) - assert isinstance(source.block_variable.hessian_value, Vector) - Hm = source.block_variable.hessian_value.inner(h.vector()) + assert isinstance(source.block_variable.adj_value, Cofunction) + assert isinstance(source.block_variable.hessian_value, Cofunction) + Hm = source.block_variable.hessian_value.dat.inner(h.dat) assert taylor_test(rf, source, h, dJdm=dJdm, Hm=Hm) > 2.9 diff --git a/tests/regression/test_assemble.py b/tests/regression/test_assemble.py index cd2d2db0d2..f604b24938 100644 --- a/tests/regression/test_assemble.py +++ b/tests/regression/test_assemble.py @@ -86,7 +86,7 @@ def M(fs): def test_one_form(M, f): one_form = assemble(action(M, f)) - assert isinstance(one_form, Function) + assert isinstance(one_form, Cofunction) for f in one_form.subfunctions: if f.function_space().rank == 2: assert abs(f.dat.data.sum() - 0.5*sum(f.function_space().shape)) < 1.0e-12 diff --git a/tests/regression/test_assemble_baseform.py b/tests/regression/test_assemble_baseform.py new file mode 100644 index 0000000000..a7bb9a77c8 --- /dev/null +++ b/tests/regression/test_assemble_baseform.py @@ -0,0 +1,294 @@ +import pytest +import numpy as np +from firedrake import * +from firedrake.assemble import allocate_matrix +from firedrake.utils import ScalarType +import ufl + + +@pytest.fixture(scope='module') +def mesh(): + return UnitSquareMesh(5, 5) + + +@pytest.fixture(scope='module', params=['cg1', 'vcg1', 'tcg1', + 'cg1cg1', 'cg1cg1[0]', 'cg1cg1[1]', + 'cg1vcg1[0]', 'cg1vcg1[1]', + 'cg1dg0', 'cg1dg0[0]', 'cg1dg0[1]', + 'cg2dg1', 'cg2dg1[0]', 'cg2dg1[1]']) +def fs(request, mesh): + cg1 = FunctionSpace(mesh, "CG", 1) + cg2 = FunctionSpace(mesh, "CG", 2) + vcg1 = VectorFunctionSpace(mesh, "CG", 1) + tcg1 = TensorFunctionSpace(mesh, "CG", 1) + dg0 = FunctionSpace(mesh, "DG", 0) + dg1 = FunctionSpace(mesh, "DG", 1) + return {'cg1': cg1, + 'vcg1': vcg1, + 'tcg1': tcg1, + 'cg1cg1': cg1*cg1, + 'cg1cg1[0]': (cg1*cg1)[0], + 'cg1cg1[1]': (cg1*cg1)[1], + 'cg1vcg1': cg1*vcg1, + 'cg1vcg1[0]': (cg1*vcg1)[0], + 'cg1vcg1[1]': (cg1*vcg1)[1], + 'cg1dg0': cg1*dg0, + 'cg1dg0[0]': (cg1*dg0)[0], + 'cg1dg0[1]': (cg1*dg0)[1], + 'cg2dg1': cg2*dg1, + 'cg2dg1[0]': (cg2*dg1)[0], + 'cg2dg1[1]': (cg2*dg1)[1]}[request.param] + + +@pytest.fixture +def f(fs): + f = Function(fs, name="f") + f_split = f.subfunctions + x = SpatialCoordinate(fs.mesh())[0] + + # NOTE: interpolation of UFL expressions into mixed + # function spaces is not yet implemented + for fi in f_split: + fs_i = fi.function_space() + if fs_i.rank == 1: + fi.interpolate(as_vector((x,) * fs_i.value_size)) + elif fs_i.rank == 2: + fi.interpolate(as_tensor([[x for i in range(fs_i.mesh().geometric_dimension())] + for j in range(fs_i.rank)])) + else: + fi.interpolate(x) + return f + + +@pytest.fixture +def one(fs): + one = Function(fs, name="one") + ones = one.subfunctions + + # NOTE: interpolation of UFL expressions into mixed + # function spaces is not yet implemented + for fi in ones: + fs_i = fi.function_space() + if fs_i.rank == 1: + fi.interpolate(Constant((1.0,) * fs_i.value_size)) + elif fs_i.rank == 2: + fi.interpolate(Constant([[1.0 for i in range(fs_i.mesh().geometric_dimension())] + for j in range(fs_i.rank)])) + else: + fi.interpolate(Constant(1.0)) + return one + + +@pytest.fixture +def M(fs): + uhat = TrialFunction(fs) + v = TestFunction(fs) + return inner(uhat, v) * dx + + +@pytest.fixture +def a(fs, f): + v = TestFunction(fs) + return inner(f, v) * dx + + +def test_assemble_cofun(a): + res = assemble(a) + assert isinstance(res, Cofunction) + + +def test_assemble_matrix(M): + res = assemble(M) + assert isinstance(res, ufl.Matrix) + + +def test_assemble_adjoint(M): + res = assemble(adjoint(M)) + assembledM = assemble(M) + res2 = assemble(adjoint(assembledM)) + assert isinstance(res, ufl.Matrix) + assert res.M.handle == res.petscmat + assert np.allclose(res.M.handle[:, :], res2.M.handle[:, :], rtol=1e-14) + + +def test_assemble_action(M, f): + res = assemble(action(M, f)) + assembledM = assemble(M) + res2 = assemble(action(assembledM, f)) + assert isinstance(res2, Cofunction) + assert isinstance(res, Cofunction) + for f, f2 in zip(res.subfunctions, res2.subfunctions): + assert abs(f.dat.data.sum() - f2.dat.data.sum()) < 1.0e-12 + if f.function_space().rank == 2: + assert abs(f.dat.data.sum() - 0.5*sum(f.function_space().shape)) < 1.0e-12 + else: + assert abs(f.dat.data.sum() - 0.5*f.function_space().value_size) < 1.0e-12 + + +def test_vector_formsum(a): + res = assemble(a) + preassemble = assemble(a + a) + formsum = res + a + res2 = assemble(formsum) + + assert isinstance(formsum, ufl.form.FormSum) + assert isinstance(res2, Cofunction) + assert isinstance(preassemble, Cofunction) + for f, f2 in zip(preassemble.subfunctions, res2.subfunctions): + assert abs(f.dat.data.sum() - f2.dat.data.sum()) < 1.0e-12 + + +def test_matrix_formsum(M): + res = assemble(M) + sumfirst = assemble(M+M) + formsum = res + M + assert isinstance(formsum, ufl.form.FormSum) + res2 = assemble(formsum) + assert isinstance(res2, ufl.Matrix) + assert np.allclose(sumfirst.petscmat[:, :], + res2.petscmat[:, :], rtol=1e-14) + + +def test_zero_form(M, f, one): + zero_form = assemble(action(action(M, f), one)) + assert isinstance(zero_form, ScalarType.type) + assert abs(zero_form - 0.5 * np.prod(f.ufl_shape)) < 1.0e-12 + + +def test_tensor_copy(a, M): + + # 1-form tensor + V = a.arguments()[0].function_space() + tensor = Cofunction(V.dual()) + formsum = assemble(a) + a + res = assemble(formsum, tensor=tensor) + + assert isinstance(formsum, ufl.form.FormSum) + assert isinstance(res, Cofunction) + for f, f2 in zip(res.subfunctions, tensor.subfunctions): + assert abs(f.dat.data.sum() - f2.dat.data.sum()) < 1.0e-12 + + # 2-form tensor + tensor = allocate_matrix(M) + formsum = assemble(M) + M + res = assemble(formsum, tensor=tensor) + + assert isinstance(formsum, ufl.form.FormSum) + assert isinstance(res, ufl.Matrix) + assert np.allclose(res.petscmat[:, :], + tensor.petscmat[:, :], rtol=1e-14) + + +def test_cofunction_assign(a, M, f): + c1 = assemble(a) + # Scale the action to obtain a different value than c1 + c2 = assemble(2 * action(M, f)) + assert isinstance(c1, Cofunction) + assert isinstance(c2, Cofunction) + + # Assign Cofunction to Cofunction + c1.assign(c2) + for a, b in zip(c1.subfunctions, c2.subfunctions): + assert np.allclose(a.dat.data, b.dat.data) + + # Assign BaseForm to Cofunction + c1.assign(action(M, f)) + for a, b in zip(c1.subfunctions, c2.subfunctions): + assert np.allclose(a.dat.data, 0.5 * b.dat.data) + + +def test_cofunction_riesz_representation(a): + # Get a Cofunction + c = assemble(a) + assert isinstance(c, Cofunction) + + V = c.function_space().dual() + u = TrialFunction(V) + v = TestFunction(V) + + # Define Riesz maps + riesz_maps = {'L2': inner(u, v) * dx, + 'H1': (inner(u, v) + inner(grad(u), grad(v))) * dx, + 'l2': None} + + # Check Riesz representation for each map + for riesz_map, mass in riesz_maps.items(): + + # Get Riesz representation of c + r = c.riesz_representation(riesz_map=riesz_map) + + assert isinstance(r, Function) + assert r.function_space() == V + + if mass: + M = assemble(mass) + Mr = Function(V) + with r.dat.vec_ro as v_vec: + with Mr.dat.vec as res_vec: + M.petscmat.mult(v_vec, res_vec) + else: + # l2 mass matrix is identity + Mr = Function(V, val=r.dat) + + # Check residual + for a, b in zip(Mr.subfunctions, c.subfunctions): + assert np.allclose(a.dat.data, b.dat.data, rtol=1e-14) + + +def test_function_riesz_representation(f): + # Get a Function + assert isinstance(f, Function) + + V = f.function_space() + u = TrialFunction(V) + v = TestFunction(V) + + # Define Riesz maps + riesz_maps = {'L2': inner(u, v) * dx, + 'H1': (inner(u, v) + inner(grad(u), grad(v))) * dx, + 'l2': None} + + # Check Riesz representation for each map + for riesz_map, mass in riesz_maps.items(): + + # Get Riesz representation of f + r = f.riesz_representation(riesz_map=riesz_map) + + assert isinstance(r, Cofunction) + assert r.function_space() == V.dual() + + if mass: + M = assemble(mass) + Mf = Function(V) + with f.dat.vec_ro as v_vec: + with Mf.dat.vec as res_vec: + M.petscmat.mult(v_vec, res_vec) + else: + # l2 mass matrix is identity + Mf = Cofunction(V.dual(), val=f.dat) + + # Check residual + for a, b in zip(Mf.subfunctions, r.subfunctions): + assert np.allclose(a.dat.data, b.dat.data, rtol=1e-14) + + +def helmholtz(r, quadrilateral=False, degree=2, mesh=None): + # Create mesh and define function space + if mesh is None: + mesh = UnitSquareMesh(2 ** r, 2 ** r, quadrilateral=quadrilateral) + V = FunctionSpace(mesh, "CG", degree) + + # Define variational problem + lmbda = 1 + u = TrialFunction(V) + v = TestFunction(V) + f = Function(V) + x = SpatialCoordinate(mesh) + f.interpolate((1+8*pi*pi)*cos(x[0]*pi*2)*cos(x[1]*pi*2)) + a = (inner(grad(u), grad(v)) + lmbda * inner(u, v)) * dx + + assembled_matrix = assemble(a) + preassemble_action = assemble(action(a, f)) + postassemble_action = assemble(action(assembled_matrix, f)) + + assert np.allclose(preassemble_action.M.values, postassemble_action.M.values, rtol=1e-14) diff --git a/tests/regression/test_bcs.py b/tests/regression/test_bcs.py index 24c9d56b46..1b977c431a 100644 --- a/tests/regression/test_bcs.py +++ b/tests/regression/test_bcs.py @@ -299,8 +299,10 @@ def test_mixed_bcs(diagonal): def test_bcs_rhs_assemble(a, V): bcs = [DirichletBC(V, 1.0, 1), DirichletBC(V, 2.0, 3)] b1 = assemble(a) + b1_func = b1.riesz_representation(riesz_map="l2") for bc in bcs: - bc.apply(b1) + bc.apply(b1_func) + b1.assign(b1_func.riesz_representation(riesz_map="l2")) b2 = assemble(a, bcs=bcs) assert np.allclose(b1.dat.data, b2.dat.data) diff --git a/tests/regression/test_linear_solver_change_bc.py b/tests/regression/test_linear_solver_change_bc.py index 987d23044a..424a4148a9 100644 --- a/tests/regression/test_linear_solver_change_bc.py +++ b/tests/regression/test_linear_solver_change_bc.py @@ -16,7 +16,7 @@ def test_linear_solver_change_bc(): bc = DirichletBC(V, bcval, "on_boundary") A = assemble(a, bcs=bc) - b = Function(V) + b = Cofunction(V.dual()) solver = LinearSolver(A) diff --git a/tests/regression/test_quadrature.py b/tests/regression/test_quadrature.py index 7054d72e41..6ec5f8b855 100644 --- a/tests/regression/test_quadrature.py +++ b/tests/regression/test_quadrature.py @@ -8,7 +8,7 @@ def test_hand_specified_quadrature(): a = conj(v) * dx - norm_q0 = norm(assemble(a, form_compiler_parameters={'quadrature_degree': 0})) - norm_q2 = norm(assemble(a, form_compiler_parameters={'quadrature_degree': 2})) + a_q0 = assemble(a, form_compiler_parameters={'quadrature_degree': 0}) + a_q2 = assemble(a, form_compiler_parameters={'quadrature_degree': 2}) - assert norm_q0 != norm_q2 + assert not np.allclose(a_q0.dat.data, a_q2.dat.data) diff --git a/tests/regression/test_solving_interface.py b/tests/regression/test_solving_interface.py index 448674f935..0493de99d5 100644 --- a/tests/regression/test_solving_interface.py +++ b/tests/regression/test_solving_interface.py @@ -219,3 +219,22 @@ def test_constant_jacobian_lvs(): lvs.solve() assert not (norm(assemble(out*5 - f)) < 2e-7) + + +def test_solve_cofunction_rhs(): + mesh = UnitSquareMesh(10, 10) + V = FunctionSpace(mesh, "CG", 1) + + u = TrialFunction(V) + v = TestFunction(V) + a = inner(u, v) * dx + + L = Cofunction(V.dual()) + L.vector()[:] = 1. + + w = Function(V) + solve(a == L, w) + + Aw = assemble(action(a, w)) + assert isinstance(Aw, Cofunction) + assert np.allclose(Aw.dat.data_ro, L.dat.data_ro) diff --git a/tests/slate/test_assemble_tensors.py b/tests/slate/test_assemble_tensors.py index d972e3588d..ac6b4663b6 100644 --- a/tests/slate/test_assemble_tensors.py +++ b/tests/slate/test_assemble_tensors.py @@ -107,7 +107,7 @@ def rank_two_tensor(mass): def test_tensor_action(mass, f): V = assemble(Tensor(mass) * AssembledVector(f)) ref = assemble(action(mass, f)) - assert isinstance(V, Function) + assert isinstance(V, Cofunction) assert np.allclose(V.dat.data, ref.dat.data, rtol=1e-14) @@ -115,13 +115,13 @@ def test_sum_tensor_actions(mass, f, g): V = assemble(Tensor(mass) * AssembledVector(f) + Tensor(0.5*mass) * AssembledVector(g)) ref = assemble(action(mass, f) + action(0.5*mass, g)) - assert isinstance(V, Function) + assert isinstance(V, Cofunction) assert np.allclose(V.dat.data, ref.dat.data, rtol=1e-14) def test_assemble_vector(rank_one_tensor): V = assemble(rank_one_tensor) - assert isinstance(V, Function) + assert isinstance(V, Cofunction) assert np.allclose(V.dat.data, assemble(rank_one_tensor.form).dat.data, rtol=1e-14) diff --git a/tests/slate/test_unaryop_precedence.py b/tests/slate/test_unaryop_precedence.py index 4168ff7ba9..3757503a8a 100644 --- a/tests/slate/test_unaryop_precedence.py +++ b/tests/slate/test_unaryop_precedence.py @@ -1,5 +1,4 @@ from firedrake import * -import numpy def test_unary_minus(): @@ -21,6 +20,8 @@ def test_unary_minus(): expr = action(A, uh) - B - assert numpy.allclose(norm(assemble(expr)), 0) + assembled_expr = assemble(expr) + assert assembled_expr.dat.norm < 1e-9 - assert numpy.allclose(norm(assemble(-expr)), 0) + assembled_expr = assemble(-expr) + assert assembled_expr.dat.norm < 1e-9