@@ -29,6 +29,7 @@ class Solver(Enum):
29
29
"""Enum for solver types."""
30
30
FORWARD = 0
31
31
ADJOINT = 1
32
+ TLM = 2
32
33
33
34
34
35
class GenericSolveBlock (Block ):
@@ -681,8 +682,11 @@ def _adjoint_solve(self, dJdu, compute_bdy):
681
682
def _ad_assign_map (self , form , solver ):
682
683
if solver == Solver .FORWARD :
683
684
count_map = self ._ad_solvers ["forward_nlvs" ]._problem ._ad_count_map
684
- else :
685
+ elif solver == Solver . ADJOINT :
685
686
count_map = self ._ad_solvers ["adjoint_lvs" ]._problem ._ad_count_map
687
+ elif solver == Solver .TLM :
688
+ count_map = self ._ad_solvers ["tlm_lvs" ]._problem ._ad_count_map
689
+
686
690
assign_map = {}
687
691
form_ad_count_map = dict ((count_map [coeff ], coeff )
688
692
for coeff in form .coefficients ())
@@ -717,9 +721,13 @@ def _ad_solver_replace_forms(self, solver=Solver.FORWARD):
717
721
problem = self ._ad_solvers ["forward_nlvs" ]._problem
718
722
self ._ad_assign_coefficients (problem .F , solver )
719
723
self ._ad_assign_coefficients (problem .J , solver )
720
- else :
724
+ elif solver == Solver . ADJOINT :
721
725
self ._ad_assign_coefficients (
722
726
self ._ad_solvers ["adjoint_lvs" ]._problem .J , solver )
727
+ elif solver == Solver .TLM :
728
+ self ._ad_assign_coefficients (
729
+ self ._ad_solvers ["tlm_lvs" ]._problem .J , solver
730
+ )
723
731
724
732
def prepare_evaluate_adj (self , inputs , adj_inputs , relevant_dependencies ):
725
733
compute_bdy = self ._should_compute_boundary_adjoint (
@@ -796,6 +804,59 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
796
804
797
805
return dFdm
798
806
807
+ def evaluate_tlm_component (self , inputs , tlm_inputs , block_variable , idx ,
808
+ prepared = None ):
809
+ F_form = prepared ["form" ]
810
+ dFdu = prepared ["dFdu" ]
811
+
812
+ bcs = []
813
+ dFdm = 0.
814
+ for block_variable in self .get_dependencies ():
815
+ tlm_value = block_variable .tlm_value
816
+ c = block_variable .output
817
+ c_rep = block_variable .saved_output
818
+
819
+ if isinstance (c , firedrake .DirichletBC ):
820
+ if tlm_value is None :
821
+ bcs .append (c .reconstruct (g = 0 ))
822
+ else :
823
+ bcs .append (tlm_value )
824
+ continue
825
+ elif isinstance (c , firedrake .MeshGeometry ):
826
+ X = firedrake .SpatialCoordinate (c )
827
+ c_rep = X
828
+
829
+ if tlm_value is None :
830
+ continue
831
+
832
+ if c == self .func and not self .linear :
833
+ continue
834
+
835
+ dFdm += firedrake .derivative (- F_form , c_rep , tlm_value )
836
+
837
+ if isinstance (dFdm , float ):
838
+ v = dFdu .arguments ()[0 ]
839
+ dFdm = firedrake .inner (
840
+ firedrake .Constant (numpy .zeros (v .ufl_shape )), v
841
+ ) * firedrake .dx
842
+
843
+ dFdm = ufl .algorithms .expand_derivatives (dFdm )
844
+ dFdm = firedrake .assemble (dFdm )
845
+
846
+ # XXX I dunno how this works
847
+ self ._ad_solver_replace_forms (Solver .TLM )
848
+ self ._ad_solvers ["tlm_lvs" ].invalidate_jacobian ()
849
+ # update RHS
850
+ self ._ad_solvers ["tlm_lvs" ]._problem .F ._components [1 ].assign (dFdm )
851
+
852
+ self ._ad_solvers ["tlm_lvs" ].solve ()
853
+ return self ._ad_solvers ["tlm_lvs" ]._problem .u
854
+ # return self._assemble_and_solve_tlm_eq(
855
+ # firedrake.assemble(dFdu, bcs=bcs, **self.assemble_kwargs),
856
+ # dFdm, dudm, bcs
857
+ # )
858
+
859
+
799
860
800
861
class ProjectBlock (SolveVarFormBlock ):
801
862
def __init__ (self , v , V , output , bcs = [], * args , ** kwargs ):
0 commit comments