-
Notifications
You must be signed in to change notification settings - Fork 231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement snapshotting for the acoustic wave equation #2474
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the contribution!
I have left some comments as it needs some changes to be mergeable. Some kind of test also needs to be added so that it is maintainable.
# Build operator equations | ||
equations = eqn + src_term + rec_term | ||
|
||
if factor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to be wrapped into a utility function as it's duplicated below
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've created a function to construct usnaps.
nsnaps = (geometry.nt + factor - 1) // factor | ||
time_subsampled = ConditionalDimension( | ||
't_sub', parent=model.grid.time_dim, factor=factor) | ||
usnaps = TimeFunction(name='usnaps', grid=model.grid, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You still have u
with full time saved line 135 you can't have both
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed.
name='Forward', **kwargs) | ||
op = Operator(equations, subs=model.spacing_map, name='Forward', **kwargs) | ||
if usnaps is not None: | ||
return op, usnaps |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No the operator build cannot return objects like that. This is an abstract operator with placeholders that might not be correct for runtime.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. The operator build only returns op now.
|
||
if factor is not None: | ||
# Condition to apply gradient update only at snapshot times | ||
condition = Eq(time % factor, 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No you don't need that usnap
already contains the conditon
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
u = TimeFunction(name='u', grid=model.grid, | ||
save=geometry.nt if save else None, | ||
time_order=2, space_order=space_order) | ||
if kernel == 'OT2': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unnecessary duplicate, u
contains the information you should not need separate cases for gradient_update
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. No cases are used.
@@ -1,6 +1,6 @@ | |||
from devito import Function, TimeFunction, DevitoCheckpoint, CheckpointOperator, Revolver | |||
from devito.tools import memoized_meth | |||
from examples.seismic.acoustic.operators import ( | |||
from devitofwi.devito.acoustic.operators import ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
leftover?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right. I did not catch it.
@@ -108,12 +111,24 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs): | |||
model = model or self.model | |||
# Pick vp from model unless explicitly provided | |||
kwargs.update(model.physical_params(**kwargs)) | |||
# Get the operator | |||
op_fwd = self.op_fwd(save=save, factor=factor) | |||
# Prepare parameters for operator apply |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't know what this is for.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed.
dt=kwargs.pop('dt', self.dt), **kwargs) | ||
|
||
return rec, u, summary | ||
if factor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, usnap needs to be create here like u
then passed as argument
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. usnaps is created now.
op_args['usnaps'] = usnaps | ||
summary = op.apply(**op_args) | ||
|
||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't need if else only kwargs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
@@ -209,8 +236,17 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None, | |||
wrp.apply_forward() | |||
summary = wrp.apply_reverse() | |||
else: | |||
summary = self.op_grad().apply(rec=rec, grad=grad, v=v, u=u, dt=dt, | |||
**kwargs) | |||
if factor is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, not needed, input u
should contain all metada needed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
…apshot_acoustic
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Hi Mathias, thank you for your feedback. I have reviewed and cleaned the code to the best of my understanding. I have included a notebook to compare computing the FWI gradient with and without snapshotting and two scripts to calculate the memory usage of both methods. After updating the code, the memory usage for calculating the gradient with snapshotting is more than twice that of the older code version. This reduced memory usage (I guess) because I was passing 'usnaps' with the operator (which is not good practice). I am wondering, is it possible to improve the code more to reduce the memory usage? |
time_order=2, space_order=space_order) | ||
rec = geometry.rec | ||
|
||
s = model.grid.stepping_dim.spacing | ||
eqn = iso_stencil(v, model, kernel, forward=False) | ||
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revert change, pep8 violation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
receivers = rec.inject(field=v.backward, expr=rec * s**2 / m) | ||
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
u = TimeFunction(name='u', grid=model.grid, save=geometry.nt if save | ||
else None, time_order=2, space_order=space_order) | ||
v = TimeFunction(name='v', grid=model.grid, save=None, | ||
if factor: # Apply the imaging condition at the snapshots of the full wavefield |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Leave a blank line between the grad =
and this if factor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move the comment inside the body of the if
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
v = TimeFunction(name='v', grid=model.grid, save=None, | ||
if factor: # Apply the imaging condition at the snapshots of the full wavefield | ||
u = create_snapshot_time_function(model, 'u', geometry, space_order, factor) | ||
else:# Apply the imaging condition at every time step of the full wavefield |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move the comment inside the body of the else
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
@@ -90,30 +91,38 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs): | |||
The time-constant velocity. | |||
save : bool, optional | |||
Whether or not to save the entire (unrolled) wavefield. | |||
factor : int, optional | |||
Downsampling factor to save snapshots of the wavefield. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indent
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
import matplotlib.pyplot as plt | ||
from scipy.ndimage import gaussian_filter | ||
import scipy | ||
from memory_profiler import memory_usage |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
imports from stdlib at the very top
then blank line
then imports from third parties (eg scipy)
then blank line
then examples imports
then devito imports
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
nsnaps = (geometry.nt + factor - 1) // factor | ||
time_subsampled = ConditionalDimension('t_sub', | ||
parent=model.grid.time_dim, factor=factor) | ||
u_ = TimeFunction(name=name, grid=model.grid, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"usnaps" for homogeneity
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
""" | ||
m = model.m | ||
|
||
# Create symbols for forward wavefield, source and receivers | ||
u = TimeFunction(name='u', grid=model.grid, | ||
save=geometry.nt if save else None, | ||
save=geometry.nt if save and factor is None else None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not a big fan of this composite conditional involving both save
and factor
, which is also repeated across other modules
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved the conditional statement outside of the u
definition but I do not know if there is a better way to avoid the composite conditional statement.
# Substitute spacing terms to reduce flops | ||
return Operator(eqn + receivers + [gradient_update], subs=model.spacing_map, | ||
name='Gradient', **kwargs) | ||
name='Gradient', **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
re-indent
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
dt=kwargs.pop('dt', self.dt), **kwargs) | ||
|
||
return rec, u, summary | ||
if factor: # Return snapshots of the forward wavefield |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since factor is passed down to op_fwd
, I don't think we need the extra if factor : .... else: ...
here, somehow it should be avoided and/or it's avoidable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a better way to do this, as the code did not run correctly without the condition? I made the return statement conditional so as not to break people's code, so I kept the number of returned objects at three.
Hi Fabio, thank you for your feedback. I have reviewed and cleaned the code to the best of my understanding. |
""" | ||
m = model.m | ||
|
||
# Create symbols for forward wavefield, source and receivers | ||
save_value = geometry.nt if save and factor is None else None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To make the code more homogeneous and easier to read, what we could do, imho, is the following:
- drop
save
from this TimeFunction (which is just legacy behaviour...) - systematically create
usnaps
and the corresponding equation - with an
inf
(akasys.maxint
?)factor
, no snapshots will be saved at runtime - Tweak here adding a line
if self.size == 0: return
, so that Devito avoids allocating memory entirely if pointless
In my opinion, this will dramatically clean up the code here.
At the moment, the proliferation of if factor
logic is still affecting maintainability too much
@mloubout is on vacation ATM, but would be useful to hear from his thoughts about this matter
@@ -107,8 +109,60 @@ def iso_stencil(field, model, kernel, **kwargs): | |||
return eqns | |||
|
|||
|
|||
def create_snapshot_time_function(model, name, geometry, space_order, factor=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would lift that into seismic/utils.py
or some new file so it can be easily reused from other pdes. It can probably be done directly from the wavefield too i,e
def create_snapshot_time_function(u, nsnap):
name = f"{u.name}_save"
grid = u.grid
....
Makes it also decoupled from those geometry/model. THis would also allow to directly return the equation Eq(u,usave)
as well
|
||
|
||
# Compute residual | ||
def compute_residual(residual, dobs, dsyn): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is a copy paste for the fwi example, import the function from there
@@ -0,0 +1,118 @@ | |||
import argparse |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like it has a lot of common with examples/seismic/inversion
it would be better to avoid duplications and update that example instead.
shape = (601, 221) | ||
spacing = (15.0, 15.0) | ||
origin = (0.0, 0.0) | ||
vel_path = '../../../devito/data/Marm.bin' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to be made available for CI somehow
ff, grad = fwi_gradient(mode, model, solver, geometry, | ||
source_locations, model0.vp, factor) | ||
mem_usage = memory_usage()[0] | ||
print(f"Memory usage at the end of gradient ({mode} mode): {mem_usage:.2f} MiB") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure what the goal of this check is here. the gradient is computed in as separate function so all the memory is freed. Showing memory usage difference would need to do this between forward and gradient for a single shot
residual = Receiver(name='residual', grid=model.grid, | ||
time_range=geometry.time_axis, | ||
coordinates=geometry.rec_positions) | ||
d_obs = Receiver(name='d_obs', grid=model.grid, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
geometry.new_rec(name='d_obs')
same for others
coordinates=geometry.rec_positions) | ||
objective = 0.0 | ||
for i in range(nshots): | ||
geometry.src_positions[0, :] = source_locations[i, :] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modifying geometry isn't good here. Instead get a source src=geometry.src
and on;y modify the src
coordinates. (I know this is just taken from the FWI tutorial but it's an old tutorial that needs updates)
for i in range(nshots): | ||
geometry.src_positions[0, :] = source_locations[i, :] | ||
solver.forward(vp=model.vp, rec=d_obs) | ||
save_value = True if mode == "full" else False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be just a bool directly at argparse
Implement snapshotting to save snapshots of the forward wavefield used to compute the gradient to reduce memory usage.