Commit ea7d35c

hendrikmakait and phofl authored
Concatenate small input chunks before P2P rechunking (#8832)
Co-authored-by: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
1 parent c073797 commit ea7d35c
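
This commit targets P2P rechunking of arrays whose inputs consist of many small chunks: instead of turning every tiny input chunk into its own transfer task, small neighbouring chunks within a partial are now concatenated first. Below is a minimal sketch (not part of the commit) of the affected code path; the in-process cluster is purely illustrative and the chunk layout mirrors the new tests further down.

import dask.array as da
from distributed import Client

# P2P rechunking requires a distributed scheduler; an in-process cluster is
# enough for a demonstration.
client = Client(processes=False)

# Ten 1-wide column chunks are rechunked into two 5-wide ones. With this
# commit, the small input chunks get concatenated before the P2P transfer,
# shrinking the task graph.
x = da.ones((10, 10), chunks=(10, 1))
y = x.rechunk((2, 5), method="p2p")
print(y.compute().shape)  # (10, 10)

client.close()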

File tree: 2 files changed (+248, -22 lines)


distributed/shuffle/_rechunk.py

Lines changed: 158 additions & 15 deletions
@@ -96,6 +96,7 @@
 
 from __future__ import annotations
 
+import math
 import mmap
 import os
 from collections import defaultdict
@@ -111,7 +112,7 @@
111112
)
112113
from concurrent.futures import ThreadPoolExecutor
113114
from dataclasses import dataclass
114-
from itertools import product
115+
from itertools import chain, product
115116
from pathlib import Path
116117
from typing import TYPE_CHECKING, Any, NamedTuple, cast
117118

@@ -124,6 +125,7 @@
124125
from dask.highlevelgraph import HighLevelGraph
125126
from dask.layers import Layer
126127
from dask.typing import Key
128+
from dask.utils import parse_bytes
127129

128130
from distributed.core import PooledRPCCall
129131
from distributed.metrics import context_meter
@@ -220,7 +222,7 @@ def rechunk_p2p(
         return da.empty(x.shape, chunks=chunks, dtype=x.dtype)
     from dask.array.core import new_da_object
 
-    prechunked = _calculate_prechunking(x.chunks, chunks)
+    prechunked = _calculate_prechunking(x.chunks, chunks, x.dtype, block_size_limit)
     if prechunked != x.chunks:
         x = cast(
             "da.Array",
@@ -433,8 +435,140 @@ def _construct_graph(self) -> _T_LowLevelGraph:
 
 
 def _calculate_prechunking(
-    old_chunks: ChunkedAxes, new_chunks: ChunkedAxes
+    old_chunks: ChunkedAxes,
+    new_chunks: ChunkedAxes,
+    dtype: np.dtype,
+    block_size_limit: int | None,
+) -> ChunkedAxes:
+    """Calculate how to perform the pre-rechunking step
+
+    During the pre-rechunking step, we
+    1. Split input chunks along partial boundaries to make partials completely independent of one another
+    2. Merge small chunks within partials to reduce the number of transfer tasks and corresponding overhead
+    """
+    split_axes = _split_chunks_along_partial_boundaries(old_chunks, new_chunks)
+
+    # We can only determine how to concatenate chunks if we can calculate block sizes.
+    has_nans = (any(math.isnan(y) for y in x) for x in old_chunks)
+
+    if len(new_chunks) <= 1 or not all(new_chunks) or any(has_nans):
+        return tuple(tuple(chain(*axis)) for axis in split_axes)
+
+    if dtype is None or dtype.hasobject or dtype.itemsize == 0:
+        return tuple(tuple(chain(*axis)) for axis in split_axes)
+
+    # We made sure that there are no NaNs in split_axes above
+    return _concatenate_small_chunks(
+        split_axes, old_chunks, new_chunks, dtype, block_size_limit  # type: ignore[arg-type]
+    )
+
+
+def _concatenate_small_chunks(
+    split_axes: list[list[list[int]]],
+    old_chunks: ChunkedAxes,
+    new_chunks: ChunkedAxes,
+    dtype: np.dtype,
+    block_size_limit: int | None,
 ) -> ChunkedAxes:
+    """Concatenate small chunks within partials.
+
+    By concatenating chunks within partials, we reduce the number of P2P transfer tasks and their
+    corresponding overhead.
+
+    The algorithm used in this function is very similar to :func:`dask.array.rechunk.find_merge_rechunk`,
+    the main difference is that we have to make sure to only merge chunks within partials.
+    """
+    import numpy as np
+
+    block_size_limit = block_size_limit or dask.config.get("array.chunk-size")
+
+    if isinstance(block_size_limit, str):
+        block_size_limit = parse_bytes(block_size_limit)
+
+    # Make it a number of elements
+    block_size_limit //= dtype.itemsize
+
+    # We verified earlier that we do not have any NaNs
+    largest_old_block = _largest_block_size(old_chunks)  # type: ignore[arg-type]
+    largest_new_block = _largest_block_size(new_chunks)  # type: ignore[arg-type]
+    block_size_limit = max([block_size_limit, largest_old_block, largest_new_block])
+
+    old_largest_width = [max(chain(*axis)) for axis in split_axes]
+    new_largest_width = [max(c) for c in new_chunks]
+
+    # This represents how much each dimension increases (>1) or reduces (<1)
+    # the graph size during rechunking
+    graph_size_effect = {
+        dim: len(new_axis) / sum(map(len, split_axis))
+        for dim, (split_axis, new_axis) in enumerate(zip(split_axes, new_chunks))
+    }
+
+    ndim = len(old_chunks)
+
+    # This represents how much each dimension increases (>1) or reduces (<1) the
+    # largest block size during rechunking
+    block_size_effect = {
+        dim: new_largest_width[dim] / (old_largest_width[dim] or 1)
+        for dim in range(ndim)
+    }
+
+    # Our goal is to reduce the number of nodes in the rechunk graph
+    # by concatenating some adjacent chunks, so consider dimensions where we can
+    # reduce the # of chunks
+    candidates = [dim for dim in range(ndim) if graph_size_effect[dim] <= 1.0]
+
+    # Concatenating along each dimension reduces the graph size by a certain factor
+    # and increases the largest block size in memory by a certain factor.
+    # We want to optimize the graph size while staying below the given
+    # block_size_limit. This is in effect a knapsack problem, except with
+    # multiplicative values and weights. Just use a greedy algorithm
+    # by trying dimensions in decreasing value / weight order.
+    def key(k: int) -> float:
+        gse = graph_size_effect[k]
+        bse = block_size_effect[k]
+        if bse == 1:
+            bse = 1 + 1e-9
+        return (np.log(gse) / np.log(bse)) if bse > 0 else 0
+
+    sorted_candidates = sorted(candidates, key=key)
+
+    concatenated_axes: list[list[int]] = [[] for i in range(ndim)]
+
+    # Copy over the axes that are not candidates unchanged
+    for i in range(ndim):
+        if i in candidates:
+            continue
+        concatenated_axes[i] = list(chain(*split_axes[i]))
+
+    # We want to concatenate chunks
+    for axis_index in sorted_candidates:
+        concatenated_axis = concatenated_axes[axis_index]
+        multiplier = math.prod(
+            old_largest_width[:axis_index] + old_largest_width[axis_index + 1 :]
+        )
+        axis_limit = block_size_limit // multiplier
+
+        for partial in split_axes[axis_index]:
+            current = partial[0]
+            for chunk in partial[1:]:
+                if (current + chunk) > axis_limit:
+                    concatenated_axis.append(current)
+                    current = chunk
+                else:
+                    current += chunk
+            concatenated_axis.append(current)
+        old_largest_width[axis_index] = max(concatenated_axis)
+    return tuple(tuple(axis) for axis in concatenated_axes)
+
+
+def _split_chunks_along_partial_boundaries(
+    old_chunks: ChunkedAxes, new_chunks: ChunkedAxes
+) -> list[list[list[float]]]:
+    """Split the old chunks along the boundaries of partials, i.e., groups of new chunks that share the same inputs.
+
+    By splitting along the boundaries before rechunking, the partials' input tasks become disjunct and each partial
+    conceptually operates on an independent sub-array.
+    """
     from dask.array.rechunk import old_to_new
 
     _old_to_new = old_to_new(old_chunks, new_chunks)
@@ -443,10 +577,13 @@ def _calculate_prechunking(
 
     split_axes = []
 
+    # Along each axis, we want to figure out how we have to split input chunks in order to make
+    # partials disjunct. We then group the resulting input chunks per partial before returning.
     for axis_index, slices in enumerate(partials):
         old_to_new_axis = _old_to_new[axis_index]
         old_axis = old_chunks[axis_index]
         split_axis = []
+        partial_chunks = []
         for slice_ in slices:
             first_new_chunk = slice_.start
             first_old_chunk, first_old_slice = old_to_new_axis[first_new_chunk][0]
@@ -465,22 +602,28 @@ def _calculate_prechunking(
                 chunk_size = last_old_slice.stop
                 if first_old_slice.start != 0:
                     chunk_size -= first_old_slice.start
-                split_axis.append(chunk_size)
-                continue
-
-            split_axis.append(first_chunk_size - first_old_slice.start)
-
-            split_axis.extend(old_axis[first_old_chunk + 1 : last_old_chunk])
-
-            if last_old_slice.stop is not None:
-                chunk_size = last_old_slice.stop
+                partial_chunks.append(chunk_size)
             else:
-                chunk_size = last_chunk_size
+                partial_chunks.append(first_chunk_size - first_old_slice.start)
 
-            split_axis.append(chunk_size)
+                partial_chunks.extend(old_axis[first_old_chunk + 1 : last_old_chunk])
 
+                if last_old_slice.stop is not None:
+                    chunk_size = last_old_slice.stop
+                else:
+                    chunk_size = last_chunk_size
+
+                partial_chunks.append(chunk_size)
+            split_axis.append(partial_chunks)
+            partial_chunks = []
+        if partial_chunks:
+            split_axis.append(partial_chunks)
         split_axes.append(split_axis)
-    return tuple(tuple(axis) for axis in split_axes)
+    return split_axes
+
+
+def _largest_block_size(chunks: tuple[tuple[int, ...], ...]) -> int:
+    return math.prod(map(max, chunks))
 
 
 def _split_partials(
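
The snippet below (not part of the diff) works through the new pre-chunking step using the same chunk layouts as test_calculate_prechunking_concatenation in the test file further down; the import path assumes the module shown above.

import dask
import numpy as np

from distributed.shuffle._rechunk import _calculate_prechunking

old = ((10,), (1,) * 10)  # one 10-row chunk, ten 1-wide column chunks
new = ((2,) * 5, (5, 5))  # desired chunking after the P2P rechunk

# With int16 data (2 bytes per element) a 40 B limit allows 20 elements per
# block; divided by the 10-element width of axis 0, chunks along axis 1 may
# grow to width 2, and merging never crosses the partial boundary between
# the two 5-wide output chunks.
with dask.config.set({"array.chunk-size": "40 B"}):
    prechunked = _calculate_prechunking(old, new, np.dtype(np.int16), None)

print(prechunked)  # ((10,), (2, 2, 1, 2, 2, 1)), as asserted in the new tests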

distributed/shuffle/tests/test_rechunk.py

Lines changed: 90 additions & 7 deletions
@@ -847,7 +847,8 @@ async def test_rechunk_avoid_needless_chunking(c, s, *ws):
     x = da.ones(16, chunks=2)
     y = x.rechunk(8, method="p2p")
     dsk = y.__dask_graph__()
-    assert len(dsk) <= 8 + 2
+    # 8 inputs, 2 concatenations of small inputs, 2 outputs
+    assert len(dsk) <= 8 + 2 + 2
 
 
 @pytest.mark.parametrize(
@@ -1337,7 +1338,7 @@ async def test_partial_rechunk_taskgroups(c, s):
         ),
         timeout=5,
     )
-    assert len(s.task_groups) < 6
+    assert len(s.task_groups) < 7
 
 
 @pytest.mark.parametrize(
@@ -1351,25 +1352,107 @@ async def test_partial_rechunk_taskgroups(c, s):
     ],
 )
 def test_calculate_prechunking_1d(old, new, expected):
-    actual = _calculate_prechunking(old, new)
+    actual = _calculate_prechunking(old, new, np.dtype, None)
     assert actual == expected
 
 
 @pytest.mark.parametrize(
     ["old", "new", "expected"],
     [
         [((2, 2), (3, 3)), ((2, 2), (3, 3)), ((2, 2), (3, 3))],
-        [((2, 2), (3, 3)), ((4,), (3, 3)), ((2, 2), (3, 3))],
+        [((2, 2), (3, 3)), ((4,), (3, 3)), ((4,), (3, 3))],
         [((2, 2), (3, 3)), ((1, 1, 1, 1), (3, 3)), ((2, 2), (3, 3))],
         [
             ((2, 2, 2), (3, 3, 3)),
             ((1, 2, 2, 1), (2, 3, 4)),
-            ((1, 1, 1, 1, 1, 1), (2, 1, 2, 1, 3)),
+            ((1, 2, 2, 1), (2, 3, 4)),
         ],
         [((1, np.nan), (3, 3)), ((1, np.nan), (2, 2, 2)), ((1, np.nan), (2, 1, 1, 2))],
-        [((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (1, 1, 1))],
+        [((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (3,))],
     ],
 )
 def test_calculate_prechunking_2d(old, new, expected):
-    actual = _calculate_prechunking(old, new)
+    actual = _calculate_prechunking(old, new, np.dtype(np.int16), None)
+    assert actual == expected
+
+
+@pytest.mark.parametrize(
+    ["old", "new", "expected"],
+    [
+        (
+            ((2, 2), (1, 1, 1, 1), (1, 1, 1, 1)),
+            ((1, 1, 1, 1), (4,), (2, 2)),
+            ((2, 2), (4,), (1, 1, 1, 1)),
+        ),
+        (
+            ((2, 2), (1, 1, 1, 1), (1, 1, 1, 1)),
+            ((1, 1, 1, 1), (2, 2), (2, 2)),
+            ((2, 2), (2, 2), (2, 2)),
+        ),
+        (
+            ((2, 2), (1, 1, 1, 1), (1, 1, 1, 1)),
+            ((1, 1, 1, 1), (2, 2), (4,)),
+            ((2, 2), (2, 2), (2, 2)),
+        ),
+        (
+            ((1, 1, 1, 1), (1, 1, 1, 1), (2, 2)),
+            ((2, 2), (4,), (1, 1, 1, 1)),
+            ((2, 2), (2, 2), (2, 2)),
+        ),
+    ],
+)
+def test_calculate_prechunking_3d(old, new, expected):
+    with dask.config.set({"array.chunk-size": "16 B"}):
+        actual = _calculate_prechunking(old, new, np.dtype(np.int16), None)
+    assert actual == expected
+
+
+@pytest.mark.parametrize(
+    ["chunk_size", "expected"],
+    [
+        ("1 B", ((10,), (1,) * 10)),
+        ("20 B", ((10,), (1,) * 10)),
+        ("40 B", ((10,), (2, 2, 1, 2, 2, 1))),
+        ("100 B", ((10,), (5, 5))),
+    ],
+)
+def test_calculate_prechunking_concatenation(chunk_size, expected):
+    old = ((10,), (1,) * 10)
+    new = ((2,) * 5, (5, 5))
+    with dask.config.set({"array.chunk-size": chunk_size}):
+        actual = _calculate_prechunking(old, new, np.dtype(np.int16), None)
+    assert actual == expected
+
+
+def test_calculate_prechunking_does_not_concatenate_object_type():
+    old = ((10,), (1,) * 10)
+    new = ((2,) * 5, (5, 5))
+
+    # Ensure that int dtypes get concatenated
+    new = ((2,) * 5, (5, 5))
+    with dask.config.set({"array.chunk-size": "100 B"}):
+        actual = _calculate_prechunking(old, new, np.dtype(np.int16), None)
+    assert actual == ((10,), (5, 5))
+
+    # Ensure object dtype chunks do not get concatenated
+    with dask.config.set({"array.chunk-size": "100 B"}):
+        actual = _calculate_prechunking(old, new, np.dtype(object), None)
+    assert actual == old
+
+
+@pytest.mark.parametrize(
+    ["old", "new", "expected"],
+    [
+        [((2, 2), (3, 3)), ((4,), (3, 3)), ((2, 2), (3, 3))],
+        [
+            ((2, 2, 2), (3, 3, 3)),
+            ((1, 2, 2, 1), (2, 3, 4)),
+            ((1, 1, 1, 1, 1, 1), (2, 1, 2, 1, 3)),
+        ],
+        [((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (1, 1, 1))],
+    ],
+)
+def test_calculate_prechunking_splitting(old, new, expected):
+    # _calculate_prechunking does not concatenate on object
+    actual = _calculate_prechunking(old, new, np.dtype(object), None)
     assert actual == expected
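
As a companion to test_calculate_prechunking_splitting, the sketch below (not part of the diff) shows the splitting-only behaviour: for an object dtype the block sizes cannot be estimated, so the pre-chunking step splits input chunks along partial boundaries but never concatenates them. The values are taken from the parametrized cases above.

import numpy as np

from distributed.shuffle._rechunk import _calculate_prechunking

old = ((2, 2, 2), (3, 3, 3))
new = ((1, 2, 2, 1), (2, 3, 4))

# Object dtype disables concatenation, so only the partial-boundary splits remain.
actual = _calculate_prechunking(old, new, np.dtype(object), None)
print(actual)  # ((1, 1, 1, 1, 1, 1), (2, 1, 2, 1, 3)), matching the test above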
