Adjoint derivative taking a bit long #1621

Open
anezkap opened this issue Mar 5, 2020 · 5 comments

@anezkap

anezkap commented Mar 5, 2020

Hi,
I am solving a system of two equations in Firedrake, and I am concerned about the disproportionate amount of time it takes to compute the derivative of my reduced functional Jhat.

The 1D problem I'm trying to solve is as follows:
[Screenshot of the governing equations (image not preserved)]

The equation for c and the solution method are essentially the same as in the "DG advection equation with upwinding" tutorial; k and k_2 are constants, and the control is the inflow value c_in.
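Since the screenshot did not survive, here is the system as it can be read off from the variational forms in the script below (the DG upwind form of the advection term plus the two reaction terms). This is a reconstruction from the code, not the original image; the initial values come from `f = Function(W)` (so c = 0) and `q.assign(1.0)`, and J is `assemble(c*ds(2))`, i.e. c at x = 1 at the final time.

```latex
\begin{aligned}
\frac{\partial c}{\partial t} + \frac{\partial (u\,c)}{\partial x} &= -k\,q\,c,
  && x \in (0, 1),\ t \in (0, T], \\
\frac{\partial q}{\partial t} &= -k_2\,q\,c, \\
c(0, t) &= c_{\mathrm{in}}, \quad c(x, 0) = 0, \quad q(x, 0) = 1, \\
J(c_{\mathrm{in}}) &= c(1, T),
\end{aligned}
```

with u = 1, k = 0.8, k_2 = 0.1 and T = 2, as set in the script.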

Solving for c and q takes only a couple of seconds, but computing Jhat.derivative() takes around 6 minutes.

Is this normal, or is there a problem in my code, or a way to make this faster?

Thank you for your help!

```python
from firedrake import *
from firedrake_adjoint import *

# Set up the mesh
mesh = UnitIntervalMesh(40)

# Set up the function spaces
Vec = VectorFunctionSpace(mesh, "CG", 1)
V_c = FunctionSpace(mesh, "DG", 1)
V_q = FunctionSpace(mesh, "DG", 0)
W = V_c*V_q

# Get the spatial coordinate for x and set constant velocity with static boundary conditions
x, = SpatialCoordinate(mesh)

velocity = as_vector((1, ))
u = Function(Vec).interpolate(velocity)
c_in = Constant(1.0)

bcs = [DirichletBC(W.sub(0), c_in, 1)]

# Set the initial condition
f = Function(W)
with stop_annotating():
    c, q = f.split()
    q.assign(1.0)


# Set the final time T and the time step dt
T = 2
dt = T/600
dtc = Constant(dt)

# Set the left-hand side of our equation
dc_trial, dq_trial = TrialFunctions(W)
phi, psi = TestFunctions(W)
a = phi*dc_trial*dx + psi*dq_trial*dx

# We define ``n`` to be the built-in ``FacetNormal`` object; a unit normal vector
# that can be used in integrals over exterior and interior facets.  We next define
# ``un`` to be an object which is equal to :math:`\vec{u}\cdot\vec{n}` if this is
# positive, and zero if this is negative. This will be useful in the upwind terms.
n = FacetNormal(mesh)
un = 0.5*(dot(u, n) + abs(dot(u, n)))

k = 0.8
k2 = 0.1

# Right-hand side
L1 = dtc*(c*div(phi*u)*dx
          - conditional(dot(u, n) < 0, phi*dot(u, n)*c_in, 0.0)*ds
          - conditional(dot(u, n) > 0, phi*dot(u, n)*c, 0.0)*ds
          - (phi('+') - phi('-'))*(un('+')*c('+') - un('-')*c('-'))*dS
          - k*phi*q*c*dx
          - k2*psi*q*c*dx)

# Right-hand sides for the second and third Runge-Kutta stages
f1 = Function(W)
f2 = Function(W)
L2 = replace(L1, {c: split(f1)[0], q: split(f1)[1]})
L3 = replace(L1, {c: split(f2)[0], q: split(f2)[1]})

# We now declare a variable to hold the temporary increments at each stage.
df = Function(W)

# We make use of the ``LinearVariationalProblem`` and
# ``LinearVariationalSolver`` objects for each of our Runge-Kutta stages.
params = {'ksp_type': 'preonly', 'pc_type': 'bjacobi', 'sub_pc_type': 'ilu', 'mat_type': 'aij'}
prob1 = LinearVariationalProblem(a, L1, df, bcs=bcs)
solv1 = LinearVariationalSolver(prob1, solver_parameters=params)
prob2 = LinearVariationalProblem(a, L2, df, bcs=bcs)
solv2 = LinearVariationalSolver(prob2, solver_parameters=params)
prob3 = LinearVariationalProblem(a, L3, df, bcs=bcs)
solv3 = LinearVariationalSolver(prob3, solver_parameters=params)

# Run the time loop with three Runge-Kutta stages, and write the results
# into the results list
t = 0.0
step = 0

with stop_annotating():
    c_, q_ = f.split()
    results = [[Function(c_)], [Function(q_)]]

while t < T - 0.5*dt:
    solv1.solve()
    f1.assign(f + df)

    solv2.solve()
    f2.assign(0.75*f + 0.25*(f1 + df))

    solv3.solve()
    f.assign((1.0/3.0)*f + (2.0/3.0)*(f2 + df))

    with stop_annotating():
        c_, q_ = f.split()
        results[0].append(Function(c_))
        results[1].append(Function(q_))

    step += 1
    t += dt

# Set up control and reduced functional Jhat
c, q = split(f)
J = assemble(c*ds(2))  # c at the right end (x = 1, boundary marker 2) at time T
m = Control(c_in)
Jhat = ReducedFunctional(J, m)


# Derivative of J with respect to the control c_in
d = Jhat.derivative()
print(d.dat.data)
```
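The three stages in the time loop are the Shu-Osher form of the third-order strong-stability-preserving Runge-Kutta scheme (SSPRK3), as in the DG advection tutorial. The sketch below applies the same update to the reaction terms alone, in plain NumPy, for a single well-mixed cell with hypothetical initial values c = q = 1 (no advection, no Firedrake). It also counts the stage solves: with dt = T/600 the forward run annotates 3 × 600 = 1800 linear solves (plus the assignments), and the adjoint has to perform roughly one adjoint solve per annotated solve.

```python
import numpy as np

K, K2 = 0.8, 0.1          # k and k2 from the script
T, N_STEPS = 2.0, 600     # T and the number of steps implied by dt = T/600
DT = T / N_STEPS


def rhs(y):
    """Reaction terms only (no advection): dc/dt = -k q c, dq/dt = -k2 q c."""
    c, q = y
    return np.array([-K * q * c, -K2 * q * c])


def ssprk3_step(y, dt):
    """One step of the same three-stage update as the script's time loop."""
    y1 = y + dt * rhs(y)                                 # f1 = f + df
    y2 = 0.75 * y + 0.25 * (y1 + dt * rhs(y1))           # f2 = 3/4 f + 1/4 (f1 + df)
    return y / 3.0 + (2.0 / 3.0) * (y2 + dt * rhs(y2))   # f = 1/3 f + 2/3 (f2 + df)


y = np.array([1.0, 1.0])               # hypothetical initial state: c = q = 1
invariant0 = y[0] - (K / K2) * y[1]    # c - (k/k2) q is conserved by the reactions
stage_solves = 0
for _ in range(N_STEPS):
    y = ssprk3_step(y, DT)
    stage_solves += 3                  # one linear solve per stage in the Firedrake script

print("final (c, q):", y)
print("stage solves on the tape:", stage_solves)
```

Because c - (k/k2) q is conserved exactly by the reaction terms and Runge-Kutta methods preserve linear invariants, the invariant makes a convenient check that the three-stage update is implemented correctly.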
@salazardetroya
Contributor

It is likely that the adjoint solver is using an iterative method and taking many iterations, because the solver parameters from the forward solve are not passed to the adjoint solve. Please see this issue in the pyadjoint repo and use the patch suggested there.
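To illustrate the general point (with SciPy, not Firedrake or PETSc): the number of Krylov iterations for the same linear system depends strongly on the solver configuration. A minimal sketch, assuming a hypothetical test matrix from an implicit first-order upwind advection step; `upwind_matrix` and `count_gmres_iterations` are illustrative helpers, not part of any library.

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla


def upwind_matrix(n, a):
    """Hypothetical test matrix: (1 + a) on the diagonal, -a on the sub-diagonal."""
    return sp.diags([np.full(n, 1.0 + a), np.full(n - 1, -a)], [0, -1], format="csc")


def count_gmres_iterations(A, b, M=None):
    """Return (number of inner GMRES iterations, SciPy convergence flag)."""
    count = 0

    def callback(_residual_norm):
        nonlocal count
        count += 1  # with callback_type="pr_norm" this fires once per inner iteration

    _, info = spla.gmres(A, b, M=M, callback=callback, callback_type="pr_norm")
    return count, info


n = 200
A = upwind_matrix(n, a=50.0)
b = np.ones(n)

# Plain (restarted) GMRES, no preconditioner.
plain_iters, plain_info = count_gmres_iterations(A, b)

# GMRES preconditioned with an incomplete LU factorization.
ilu = spla.spilu(A, permc_spec="NATURAL")
M = spla.LinearOperator(A.shape, matvec=ilu.solve)
ilu_iters, ilu_info = count_gmres_iterations(A, b, M=M)

print(f"no preconditioner:  {plain_iters} iterations (info={plain_info})")
print(f"ILU preconditioner: {ilu_iters} iterations (info={ilu_info})")
```

For this bidiagonal matrix the incomplete factorization with natural ordering has no fill to drop, so the preconditioned solve converges almost immediately, while the unpreconditioned one needs many iterations.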

@anezkap
Author

anezkap commented Mar 10, 2020

Thank you for your help. I thought that this (many iterations because the forward solver parameters are not passed) might be the problem as well, but unfortunately the suggested patch did not help.
But it's okay. I do not really need my code to be super fast at the moment.

@florianwechsung
Contributor

Just to check: did you verify that the options are still not passed, or did you only observe that the code is still slow? In the latter case, there may be a separate issue. I suspect the problem is that pyadjoint solves all adjoint equations via the solve(...) interface, which has a fair bit of overhead.
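The per-call overhead point can be made concrete with a generic SciPy analogy. This does not model what pyadjoint does internally: the forward script builds each LinearVariationalSolver once and reuses it, whereas per the comment above the adjoint goes through solve(...) for each adjoint equation, which can repeat setup work on every call. Here the repeated setup is an LU factorization, and the matrix is a hypothetical stand-in for the constant left-hand side.

```python
import time

import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla

n, n_solves = 500, 1800   # 1800 = 3 stages x 600 steps, as in the script
# Hypothetical stand-in for a constant, well-conditioned left-hand side.
A = sp.diags([np.full(n - 1, 0.1), np.full(n, 2.0), np.full(n - 1, 0.1)],
             [-1, 0, 1], format="csc")
rng = np.random.default_rng(0)
rhs_list = [rng.standard_normal(n) for _ in range(n_solves)]

# One-shot solves: the matrix is factorized again on every call.
t0 = time.perf_counter()
x_oneshot = [spla.spsolve(A, b) for b in rhs_list]
t_oneshot = time.perf_counter() - t0

# Factorize once, then reuse the factors for every right-hand side.
t0 = time.perf_counter()
lu = spla.splu(A)
x_reused = [lu.solve(b) for b in rhs_list]
t_reused = time.perf_counter() - t0

print(f"factorize every call: {t_oneshot:.3f} s")
print(f"factorize once:       {t_reused:.3f} s")
```

Both variants give the same solutions; only the amount of repeated setup differs, and the actual timings depend on the machine.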

@anezkap
Author

anezkap commented May 25, 2020

Hi again,
I'm sorry I left this open for so long. It turns out that I could actually use a somewhat faster code, even though it is not crucial.
I only observed that my code is still slow; I did not verify whether the options are passed or not. I'm not sure how to do that. Could you give me some pointers on how to verify it, please? Thanks a lot
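One way to see where Jhat.derivative() spends its time (for example inside the linear solves versus in Python-level overhead around each adjoint solve) is to profile the call with Python's built-in cProfile. This does not by itself show which solver options the adjoint uses; it only shows where the time goes. A minimal sketch: `profile_call` is a hypothetical helper, and a dummy function stands in for `Jhat.derivative` so the snippet runs without Firedrake.

```python
import cProfile
import io
import pstats


def profile_call(func, *args, top=15, **kwargs):
    """Run func(*args, **kwargs) under cProfile.

    Returns (result, report), where report lists the `top` entries sorted by
    cumulative time. Hypothetical helper, not part of Firedrake or pyadjoint.
    """
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args, **kwargs)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats("cumulative").print_stats(top)
    return result, stream.getvalue()


def dummy_work(n):
    """Stand-in for Jhat.derivative(), so the sketch runs on its own."""
    return sum(i * i for i in range(n))


result, report = profile_call(dummy_work, 10_000)
print(report)
```

In the script above, the corresponding call would be `d, report = profile_call(Jhat.derivative, top=30)` followed by `print(report)`.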

@wence-
Contributor

wence- commented Sep 2, 2020

We recently (today) merged some changes (#1804) that make this kind of use faster. Can you update and check?
