Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions dask_cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .cuda_worker import CUDAWorker
from .explicit_comms.dataframe.shuffle import (
get_rearrange_by_column_wrapper,
get_default_shuffle_algorithm,
get_default_shuffle_method,
)
from .local_cuda_cluster import LocalCUDACluster
from .proxify_device_objects import proxify_decorator, unproxify_decorator
Expand All @@ -28,11 +28,11 @@
dask.dataframe.shuffle.rearrange_by_column = get_rearrange_by_column_wrapper(
dask.dataframe.shuffle.rearrange_by_column
)
# We have to replace all modules that imports Dask's `get_default_shuffle_algorithm()`
# We have to replace all modules that imports Dask's `get_default_shuffle_method()`
# TODO: introduce a shuffle-algorithm dispatcher in Dask so we don't need this hack
dask.dataframe.shuffle.get_default_shuffle_algorithm = get_default_shuffle_algorithm
dask.dataframe.multi.get_default_shuffle_algorithm = get_default_shuffle_algorithm
dask.bag.core.get_default_shuffle_algorithm = get_default_shuffle_algorithm
dask.dataframe.shuffle.get_default_shuffle_method = get_default_shuffle_method
dask.dataframe.multi.get_default_shuffle_method = get_default_shuffle_method
dask.bag.core.get_default_shuffle_method = get_default_shuffle_method


# Monkey patching Dask to make use of proxify and unproxify in compatibility mode
Expand Down
4 changes: 2 additions & 2 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def wrapper(*args, **kwargs):
return wrapper


def get_default_shuffle_algorithm() -> str:
def get_default_shuffle_method() -> str:
"""Return the default shuffle algorithm used by Dask

This changes the default shuffle algorithm from "p2p" to "tasks"
Expand All @@ -594,4 +594,4 @@ def get_default_shuffle_algorithm() -> str:
ret = dask.config.get("dataframe.shuffle.algorithm", None)
if ret is None and _use_explicit_comms():
return "tasks"
return dask.utils.get_default_shuffle_algorithm()
return dask.utils.get_default_shuffle_method()