Minimize duplication in `map_blocks` task graph #8412
Conversation
Thanks a lot @dcherian! (I don't have enough context to know how severe the change to the parallelism is. I do really appreciate that …)
We do filter the indexes. The problem is that the filtered index values are duplicated a very large number of times for the calculation. The duplication allows the graph to be embarrassingly parallel. And then we include them a second time to enable nice error messages.
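As a rough illustration of that scaling (a sketch of mine, not the actual `map_blocks` internals): the task count, and with it the number of embedded coordinate copies, grows with the number of chunks.

```python
# Rough sketch (not xarray internals): more chunks means more tasks, and each
# task carries its own copy of the coordinate data it needs.
import numpy as np
import xarray as xr
from dask.base import collections_to_dsk

ds = xr.Dataset(
    {"air": (("time", "lat", "lon"), np.zeros((10, 25, 53)))},
    coords={"time": np.arange(10), "lat": np.arange(25), "lon": np.arange(53)},
)

for lat_chunk in (5, 1):
    mapped = ds.chunk(lat=lat_chunk, lon=1).map_blocks(lambda x: x)
    dsk = collections_to_dsk([mapped])
    print(f"lat chunk size {lat_chunk}: {len(dsk)} tasks")
```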
Ah right, yes. I confirmed that: the size difference scales by …
Defer to you on how this affects dask stability...
@fjetter do you think dask/distributed will handle the change in graph topology in the OP gracefully? |
At least if the idea of "working around scheduling issues" is to forcefully flatten the graph to a purely embarrassingly parallel workload, this property is now gone, but I believe you are still fine. I am not super familiar with xarray datasets so I am doing a bit of guesswork here.

IIUC this example dataset has three coordinates / indices … Then there is also the … If this is all correct, then yes, this is handled gracefully by dask (at least with the latest release, haven't checked older ones):

```python
import xarray as xr
from dask.utils import key_split
from dask.order import diagnostics
from dask.base import collections_to_dsk

da = xr.tutorial.load_dataset('air_temperature')
# Materialize the task graph for a map_blocks over many small chunks
dsk = collections_to_dsk([da.chunk(lat=1, lon=1).map_blocks(lambda x: x)])
diag, _ = diagnostics(dsk)
# Check that every data task is released right after it is used (age == 1)
ages_data_tasks = [
    v.age == 1
    for k, v in diag.items()
    if key_split(k).startswith('xarray-air')
]
assert ages_data_tasks
assert all(ages_data_tasks)
```

Age refers to the number of "ticks / time steps" this task survives.
If those indices are truly always numpy arrays, I would probably suggest just slicing them to whatever size they need for the given task and embedding them, keeping the embarrassingly parallel workload. I think I do not understand this problem sufficiently; it feels like I'm missing something.
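The suggestion, sketched with made-up names (a hypothetical illustration, not xarray code), would look roughly like this:

```python
# Hypothetical sketch: slice the numpy-backed index down to each chunk's extent
# and embed only that small piece, so every task stays independent of the others.
import numpy as np

lat = np.linspace(-90, 90, 25)                        # full index as a numpy array
chunk_bounds = [(i, i + 5) for i in range(0, 25, 5)]  # made-up chunk extents

def user_func(index_piece, block):
    # stand-in for the function applied to each block
    return block

graph = {}
for i, (start, stop) in enumerate(chunk_bounds):
    # each task embeds only its own slice of the index, not the full thing
    graph[("mapped", i)] = (user_func, lat[start:stop], np.zeros(stop - start))
```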
Broadcasting means that the tiny shards get duplicated a very large number of times in the graph. The OP was prompted by a 1GB task graph.
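For scale, one rough way to gauge the serialized graph size (an illustrative sketch, not taken from the original issue):

```python
# Serialize the materialized graph and count bytes to approximate its size.
import cloudpickle
import xarray as xr
from dask.base import collections_to_dsk

ds = xr.tutorial.load_dataset("air_temperature")
dsk = dict(collections_to_dsk([ds.chunk(lat=1, lon=1).map_blocks(lambda x: x)]))
print(f"{len(cloudpickle.dumps(dsk)) / 1e6:.1f} MB of serialized graph")
```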
ba52ec0 to a4bda14
* main:
  Adapt map_blocks to use new Coordinates API (pydata#8560)
  add xeofs to ecosystem.rst (pydata#8561)
  Offer a fixture for unifying DataArray & Dataset tests (pydata#8533)
  Generalize cumulative reduction (scan) to non-dask types (pydata#8019)
* upstream/main:
  Faster encoding functions. (pydata#8565)
  ENH: vendor SerializableLock from dask and use as default backend lock, adapt tests (pydata#8571)
  Silence a bunch of CachingFileManager warnings (pydata#8584)
  Bump actions/download-artifact from 3 to 4 (pydata#8556)
  Minimize duplication in `map_blocks` task graph (pydata#8412)
  [pre-commit.ci] pre-commit autoupdate (pydata#8578)
  ignore a `DeprecationWarning` emitted by `seaborn` (pydata#8576)
  Fix mypy type ignore (pydata#8564)
  Support for the new compression arguments. (pydata#7551)
  FIX: reverse index output of bottleneck move_argmax/move_argmin functions (pydata#8552)
Builds on #8560
`.map_blocks` with many chunks can be huge #8409
whats-new.rst
cc @max-sixty
This is a quick attempt. I think we can generalize this to minimize duplication.
The downside is that the graphs are not totally embarrassingly parallel any more.
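One way to see the new shared dependencies (an illustrative sketch using the tutorial `air_temperature` dataset, not part of this PR) is to count tasks that feed more than one downstream task:

```python
# In a purely embarrassingly parallel graph this count is (close to) zero;
# shared index tasks push it above zero.
import xarray as xr
from dask.base import collections_to_dsk
from dask.core import get_dependencies

ds = xr.tutorial.load_dataset("air_temperature")
dsk = dict(collections_to_dsk([ds.chunk(lat=1, lon=1).map_blocks(lambda x: x)]))

dependents = {key: 0 for key in dsk}
for key in dsk:
    for dep in get_dependencies(dsk, key):
        dependents[dep] += 1

shared = sum(count > 1 for count in dependents.values())
print(f"{shared} tasks are consumed by more than one downstream task")
```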
This PR:
vs main: