Implement snapshotting for the acoustic wave equation #2474

Open
wants to merge 10 commits into
base: master

Conversation

malfarhan7

Implement snapshotting: save subsampled snapshots of the forward wavefield used to compute the gradient, in order to reduce memory usage.

Contributor

@mloubout left a comment

Thank you for the contribution!

I have left some comments as it needs some changes to be mergeable. Some kind of test also needs to be added so that it is maintainable.

# Build operator equations
equations = eqn + src_term + rec_term

if factor:
Contributor

This needs to be wrapped into a utility function as it's duplicated below

Author

I've created a function to construct usnaps.

nsnaps = (geometry.nt + factor - 1) // factor
time_subsampled = ConditionalDimension(
    't_sub', parent=model.grid.time_dim, factor=factor)
usnaps = TimeFunction(name='usnaps', grid=model.grid,
Contributor

You still have u with the full time history saved at line 135; you can't have both.

Author

Removed.
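
For reference, the `nsnaps` expression in the snippet above is a ceiling division: it counts the time steps `t` in `range(nt)` for which `t % factor == 0`. A minimal pure-Python check follows; it does not use Devito, the helper names are hypothetical, and the values of `nt` and `factor` are chosen only for illustration.

```python
def count_snapshots(nt, factor):
    """Ceiling division: number of steps in range(nt) that are multiples of factor."""
    return (nt + factor - 1) // factor


def snapshot_steps(nt, factor):
    """Time steps at which a snapshot would be taken (hypothetical helper)."""
    return [t for t in range(nt) if t % factor == 0]


nt, factor = 10, 4
steps = snapshot_steps(nt, factor)      # [0, 4, 8]
slots = [t // factor for t in steps]    # [0, 1, 2], the index into the snapshot buffer
assert len(steps) == count_snapshots(nt, factor)
```

Each saved step `t` maps to slot `t // factor`, so a buffer with `nsnaps` entries along time is exactly large enough.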

name='Forward', **kwargs)
op = Operator(equations, subs=model.spacing_map, name='Forward', **kwargs)
if usnaps is not None:
    return op, usnaps
Contributor

No, the operator build cannot return objects like that. This is an abstract operator with placeholders that might not be correct at runtime.

Author

Fixed. The operator build only returns op now.


if factor is not None:
    # Condition to apply gradient update only at snapshot times
    condition = Eq(time % factor, 0)
Contributor

No, you don't need that; usnaps already contains the condition.

Author

Fixed.

u = TimeFunction(name='u', grid=model.grid,
                 save=geometry.nt if save else None,
                 time_order=2, space_order=space_order)
if kernel == 'OT2':
Contributor

Unnecessary duplicate: u contains the information, so you should not need separate cases for gradient_update.

Author

Fixed. No cases are used.

@@ -1,6 +1,6 @@
from devito import Function, TimeFunction, DevitoCheckpoint, CheckpointOperator, Revolver
from devito.tools import memoized_meth
-from examples.seismic.acoustic.operators import (
+from devitofwi.devito.acoustic.operators import (
Contributor

leftover?

Author

You are right. I did not catch it.

@@ -108,12 +111,24 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
model = model or self.model
# Pick vp from model unless explicitly provided
kwargs.update(model.physical_params(**kwargs))
# Get the operator
op_fwd = self.op_fwd(save=save, factor=factor)
# Prepare parameters for operator apply
Contributor

Don't know what this is for.

Author

Removed.

dt=kwargs.pop('dt', self.dt), **kwargs)

return rec, u, summary
if factor:
Contributor

No, usnaps needs to be created here, like u, and then passed as an argument.

Author

Fixed. usnaps is created now.

op_args['usnaps'] = usnaps
summary = op.apply(**op_args)

else:
Contributor

This shouldn't need an if/else, only kwargs.

Author

Fixed

@@ -209,8 +236,17 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
wrp.apply_forward()
summary = wrp.apply_reverse()
else:
summary = self.op_grad().apply(rec=rec, grad=grad, v=v, u=u, dt=dt,
**kwargs)
if factor is not None:
Contributor

Again, not needed; the input u should contain all the metadata needed.

Author

@malfarhan7 Nov 5, 2024

Fixed.

@mloubout added the examples label Oct 28, 2024

@malfarhan7
Author

Hi Mathias, thank you for your feedback. I have reviewed and cleaned up the code to the best of my understanding. I have included a notebook that compares computing the FWI gradient with and without snapshotting, and two scripts that measure the memory usage of both methods. After updating the code, the memory usage for calculating the gradient with snapshotting is more than twice that of the older code version. The older version's lower memory usage was (I guess) because I was passing 'usnaps' with the operator (which is not good practice). Is it possible to improve the code further to reduce the memory usage?

time_order=2, space_order=space_order)
rec = geometry.rec

s = model.grid.stepping_dim.spacing
eqn = iso_stencil(v, model, kernel, forward=False)

Contributor

Revert this change; it is a PEP 8 violation.

Author

Fixed

receivers = rec.inject(field=v.backward, expr=rec * s**2 / m)

Contributor

same as above

Author

Fixed

u = TimeFunction(name='u', grid=model.grid, save=geometry.nt if save
else None, time_order=2, space_order=space_order)
v = TimeFunction(name='v', grid=model.grid, save=None,
if factor: # Apply the imaging condition at the snapshots of the full wavefield
Contributor

Leave a blank line between the grad = and this if factor

Contributor

Move the comment inside the body of the if

Author

Fixed

v = TimeFunction(name='v', grid=model.grid, save=None,
if factor: # Apply the imaging condition at the snapshots of the full wavefield
u = create_snapshot_time_function(model, 'u', geometry, space_order, factor)
else:# Apply the imaging condition at every time step of the full wavefield
Contributor

Move the comment inside the body of the else

Author

Fixed

@@ -90,30 +91,38 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
The time-constant velocity.
save : bool, optional
Whether or not to save the entire (unrolled) wavefield.
factor : int, optional
Downsampling factor to save snapshots of the wavefield.
Contributor

Indent

Author

Fixed

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
import scipy
from memory_profiler import memory_usage
Contributor

imports from stdlib at the very top

then blank line

then imports from third parties (eg scipy)

then blank line

then examples imports
then devito imports

Author

Fixed

nsnaps = (geometry.nt + factor - 1) // factor
time_subsampled = ConditionalDimension('t_sub',
                                       parent=model.grid.time_dim, factor=factor)
u_ = TimeFunction(name=name, grid=model.grid,
Contributor

"usnaps" for homogeneity

Author

Fixed

"""
m = model.m

# Create symbols for forward wavefield, source and receivers
u = TimeFunction(name='u', grid=model.grid,
-save=geometry.nt if save else None,
+save=geometry.nt if save and factor is None else None,
Contributor

I'm not a big fan of this composite conditional involving both save and factor, which is also repeated across other modules.

Author

Moved the conditional statement outside of the u definition but I do not know if there is a better way to avoid the composite conditional statement.
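
One way to at least keep the combined check in a single place, rather than repeating it across modules, is a small helper. The sketch below is only an illustration: `resolve_save` is a hypothetical name that does not exist in Devito or in this PR, and it simply encodes the `save and factor is None` rule from the diff above.

```python
def resolve_save(nt, save=False, factor=None):
    """Return the value to pass as `save` for the full wavefield.

    Hypothetical helper: the full history (nt steps) is kept only when
    `save` is requested and no snapshot factor is given.
    """
    if save and factor is None:
        return nt
    return None


# Illustrative values only
print(resolve_save(100, save=True))            # full history requested
print(resolve_save(100, save=True, factor=4))  # snapshots take over, no full save
```

This does not remove the composite condition, it only centralises it, so it is a smaller step than restructuring how `save` and `factor` interact.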

# Substitute spacing terms to reduce flops
return Operator(eqn + receivers + [gradient_update], subs=model.spacing_map,
-name='Gradient', **kwargs)
+name='Gradient', **kwargs)
Contributor

re-indent

Author

Fixed

dt=kwargs.pop('dt', self.dt), **kwargs)

return rec, u, summary
if factor: # Return snapshots of the forward wavefield
Contributor

Since factor is passed down to op_fwd, I don't think we need the extra if factor: ... else: ... here; it should be avoidable.

Author

Is there a better way to do this, as the code did not run correctly without the condition? I made the return statement conditional so as not to break people's code, so I kept the number of returned objects at three.

@malfarhan7
Author

Hi Fabio, thank you for your feedback. I have reviewed and cleaned the code to the best of my understanding.

"""
m = model.m

# Create symbols for forward wavefield, source and receivers
save_value = geometry.nt if save and factor is None else None
Contributor

To make the code more homogeneous and easier to read, what we could do, imho, is the following:

  • drop save from this TimeFunction (which is just legacy behaviour...)
  • systematically create usnaps and the corresponding equation
  • with an inf (aka sys.maxint ?) factor, no snapshots will be saved at runtime
  • Tweak here adding a line if self.size == 0: return, so that Devito avoids allocating memory entirely if pointless

In my opinion, this will dramatically clean up the code here.

At the moment, the proliferation of if factor logic is still affecting maintainability too much

@mloubout is on vacation ATM, but it would be useful to hear his thoughts on this matter.
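
The idea in the list above (always create the snapshot buffer, and skip allocation when it would be empty) can be sketched in plain Python. This is a toy model, not Devito's allocation code: `SnapshotBuffer` and its attributes are hypothetical, and treating any `factor >= nt` as "no snapshots" is an assumption of the sketch. Note that Python 3 has no `sys.maxint`; `sys.maxsize` is the closest equivalent.

```python
import sys


class SnapshotBuffer:
    """Toy stand-in for a snapshot wavefield; not Devito's implementation."""

    def __init__(self, nt, factor=sys.maxsize, points_per_step=1):
        # Assumption of this sketch: a factor at least as large as nt
        # means that no snapshot is requested at all.
        self.nsnaps = 0 if factor >= nt else (nt + factor - 1) // factor
        self.size = self.nsnaps * points_per_step
        self._data = None

    def allocate(self):
        if self.size == 0:
            return  # nothing to store, so skip the allocation entirely
        self._data = [0.0] * self.size

    @property
    def allocated(self):
        return self._data is not None


no_snaps = SnapshotBuffer(nt=100)            # default "infinite" factor
no_snaps.allocate()
with_snaps = SnapshotBuffer(nt=100, factor=4, points_per_step=10)
with_snaps.allocate()
print(no_snaps.allocated, with_snaps.allocated, with_snaps.nsnaps)
```

With this shape the calling code never branches on `factor`; the empty case is handled once, at allocation time.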

@@ -107,8 +109,60 @@ def iso_stencil(field, model, kernel, **kwargs):
return eqns


def create_snapshot_time_function(model, name, geometry, space_order, factor=None):
Contributor

I would lift that into seismic/utils.py or some new file so it can be easily reused from other PDEs. It can probably be done directly from the wavefield too, i.e.

def create_snapshot_time_function(u, nsnap):

    name = f"{u.name}_save"
    grid = u.grid
    
    ....
    

This also decouples it from the geometry/model. It would also allow directly returning the equation Eq(u,usave).



# Compute residual
def compute_residual(residual, dobs, dsyn):
Contributor

This is a copy-paste from the FWI example; import the function from there.

@@ -0,0 +1,118 @@
import argparse
Contributor

This looks like it has a lot in common with examples/seismic/inversion. It would be better to avoid duplication and update that example instead.

shape = (601, 221)
spacing = (15.0, 15.0)
origin = (0.0, 0.0)
vel_path = '../../../devito/data/Marm.bin'
Contributor

This needs to be made available for CI somehow

ff, grad = fwi_gradient(mode, model, solver, geometry,
                        source_locations, model0.vp, factor)
mem_usage = memory_usage()[0]
print(f"Memory usage at the end of gradient ({mode} mode): {mem_usage:.2f} MiB")
Contributor

Not sure what the goal of this check is here. The gradient is computed in a separate function, so all the memory is freed by the time this runs. Showing the memory usage difference would require measuring between the forward and the gradient for a single shot.
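
The point about freed memory can be illustrated with the standard-library `tracemalloc` module, used here in place of the `memory_profiler` call in the script. `fake_gradient` is a hypothetical stand-in for the real gradient computation, and `tracemalloc` only tracks allocations made through Python's allocators, so this sketch makes no claim about how Devito's own buffers would show up.

```python
import tracemalloc


def fake_gradient(n):
    """Hypothetical stand-in for a gradient call that needs a large temporary."""
    wavefield = [0.0] * n      # large temporary, released when the function returns
    return sum(wavefield)


tracemalloc.start()
fake_gradient(1_000_000)
current, peak = tracemalloc.get_traced_memory()  # both values are in bytes
tracemalloc.stop()

print(f"still allocated after return: {current / 2**20:.2f} MiB")
print(f"peak during the call:         {peak / 2**20:.2f} MiB")
```

A reading taken after the function returns sees only what is still alive, while the peak captures the temporary. That is why a single measurement after `fwi_gradient` returns says little about the cost of saving the wavefield.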

residual = Receiver(name='residual', grid=model.grid,
                    time_range=geometry.time_axis,
                    coordinates=geometry.rec_positions)
d_obs = Receiver(name='d_obs', grid=model.grid,
Contributor

Use geometry.new_rec(name='d_obs'); same for the others.

coordinates=geometry.rec_positions)
objective = 0.0
for i in range(nshots):
    geometry.src_positions[0, :] = source_locations[i, :]
Contributor

Modifying geometry isn't good here. Instead, get a source with src=geometry.src and only modify the src coordinates. (I know this is just taken from the FWI tutorial, but it's an old tutorial that needs updates.)

for i in range(nshots):
    geometry.src_positions[0, :] = source_locations[i, :]
    solver.forward(vp=model.vp, rec=d_obs)
    save_value = True if mode == "full" else False
Contributor

This should just be a bool set directly in argparse.
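
A minimal sketch of that suggestion, using the standard `argparse` `store_true` action so the parsed value is already a bool. The flag name `--full` and the description are hypothetical and are not taken from the PR's script.

```python
import argparse

parser = argparse.ArgumentParser(description="Illustrative CLI, not the PR's script")
# Hypothetical flag: when present, save the full wavefield instead of snapshots.
parser.add_argument("--full", action="store_true",
                    help="save the full wavefield instead of snapshots")

# Parse an explicit argument list so the example runs without a real command line.
args = parser.parse_args(["--full"])
save_value = args.full  # already a bool, no comparison against a mode string needed
print(save_value)
```

Without the flag, `args.full` defaults to `False`, which replaces the `True if mode == "full" else False` expression.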

Labels
examples
3 participants