Support dask_expr migration into dask.dataframe #17704

Merged · 12 commits · Jan 16, 2025 · Changes from all commits
7 changes: 1 addition & 6 deletions python/dask_cudf/dask_cudf/__init__.py
@@ -7,16 +7,11 @@
import cudf

from . import backends, io # noqa: F401
from ._expr import collection # noqa: F401
from ._expr.expr import _patch_dask_expr
from ._version import __git_commit__, __version__ # noqa: F401
from .core import DataFrame, Index, Series, _deprecated_api, concat, from_cudf

if not (QUERY_PLANNING_ON := dd._dask_expr_enabled()):
raise ValueError(
"The legacy DataFrame API is not supported in dask_cudf>24.12. "
"Please enable query-planning, or downgrade to dask_cudf<=24.12"
)


def read_csv(*args, **kwargs):
with config.set({"dataframe.backend": "cudf"}):
97 changes: 96 additions & 1 deletion python/dask_cudf/dask_cudf/_expr/__init__.py
@@ -1 +1,96 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.

from packaging.version import Version

import dask

if Version(dask.__version__) > Version("2024.12.1"):
import dask.dataframe.dask_expr._shuffle as _shuffle_module
from dask.dataframe.dask_expr import (
DataFrame as DXDataFrame,
FrameBase,
Index as DXIndex,
Series as DXSeries,
from_dict,
get_collection_type,
new_collection,
)
from dask.dataframe.dask_expr._cumulative import (
CumulativeBlockwise,
)
from dask.dataframe.dask_expr._expr import (
Elemwise,
Expr,
RenameAxis,
VarColumns,
)
from dask.dataframe.dask_expr._groupby import (
DecomposableGroupbyAggregation,
GroupBy as DXGroupBy,
GroupbyAggregation,
SeriesGroupBy as DXSeriesGroupBy,
SingleAggregation,
)
from dask.dataframe.dask_expr._reductions import (
Reduction,
Var,
)
from dask.dataframe.dask_expr._util import (
_convert_to_list,
_raise_if_object_series,
is_scalar,
)
from dask.dataframe.dask_expr.io.io import (
FusedIO,
FusedParquetIO,
)
from dask.dataframe.dask_expr.io.parquet import (
FragmentWrapper,
ReadParquetFSSpec,
ReadParquetPyarrowFS,
)
else:
import dask_expr._shuffle as _shuffle_module # noqa: F401
from dask_expr import ( # noqa: F401
DataFrame as DXDataFrame,
FrameBase,
Index as DXIndex,
Series as DXSeries,
from_dict,
get_collection_type,
new_collection,
)
from dask_expr._cumulative import CumulativeBlockwise # noqa: F401
from dask_expr._expr import ( # noqa: F401
Elemwise,
Expr,
RenameAxis,
VarColumns,
)
from dask_expr._groupby import ( # noqa: F401
DecomposableGroupbyAggregation,
GroupBy as DXGroupBy,
GroupbyAggregation,
SeriesGroupBy as DXSeriesGroupBy,
SingleAggregation,
)
from dask_expr._reductions import Reduction, Var # noqa: F401
from dask_expr._util import ( # noqa: F401
_convert_to_list,
_raise_if_object_series,
is_scalar,
)
from dask_expr.io.io import FusedIO, FusedParquetIO # noqa: F401
from dask_expr.io.parquet import ( # noqa: F401
FragmentWrapper,
ReadParquetFSSpec,
ReadParquetPyarrowFS,
)

from dask.dataframe import _dask_expr_enabled

if not _dask_expr_enabled():
raise ValueError(
"The legacy DataFrame API is not supported for RAPIDS >24.12. "
"The 'dataframe.query-planning' config must be True or None."
)
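For context on the new guard above: dask reads the `dataframe.query-planning` config when `dask.dataframe` is first imported, so the setting has to be in place before that import happens. A minimal sketch of a compliant setup (illustrative, not part of this diff):

```python
import dask

# Query planning must be True or left unset for dask_cudf >24.12; setting
# it to False before importing dask.dataframe would trip the ValueError
# raised in dask_cudf._expr above.
dask.config.set({"dataframe.query-planning": True})

import dask.dataframe as dd

assert dd._dask_expr_enabled()  # the same check this module performs
```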
20 changes: 10 additions & 10 deletions python/dask_cudf/dask_cudf/_expr/collection.py
@@ -3,23 +3,23 @@
import warnings
from functools import cached_property

from dask_expr import (
DataFrame as DXDataFrame,
FrameBase,
Index as DXIndex,
Series as DXSeries,
get_collection_type,
)
from dask_expr._collection import new_collection
from dask_expr._util import _raise_if_object_series

from dask import config
from dask.dataframe.core import is_dataframe_like
from dask.dataframe.dispatch import get_parallel_type
from dask.typing import no_default

import cudf

from dask_cudf._expr import (
DXDataFrame,
DXIndex,
DXSeries,
FrameBase,
_raise_if_object_series,
get_collection_type,
new_collection,
)

##
## Custom collection classes
##
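The reshuffled imports above are the core migration pattern: consumers like `collection.py` stop importing from `dask_expr` directly and go through the `dask_cudf._expr` shim, so the dask-version branch lives in exactly one module. A generic sketch of the same technique (module and names hypothetical):

```python
# _compat.py -- the one place that branches on the dask version
from packaging.version import Version

import dask

if Version(dask.__version__) > Version("2024.12.1"):
    # newer dask: dask-expr is vendored into dask.dataframe
    from dask.dataframe.dask_expr import DataFrame, new_collection  # noqa: F401
else:
    # older dask: dask-expr is a standalone package
    from dask_expr import DataFrame, new_collection  # noqa: F401

# consumer.py -- never branches on the version itself:
# from myproject._compat import DataFrame, new_collection
```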
18 changes: 12 additions & 6 deletions python/dask_cudf/dask_cudf/_expr/expr.py
@@ -1,12 +1,6 @@
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
import functools

import dask_expr._shuffle as _shuffle_module
from dask_expr import new_collection
from dask_expr._cumulative import CumulativeBlockwise
from dask_expr._expr import Elemwise, Expr, RenameAxis, VarColumns
from dask_expr._reductions import Reduction, Var

from dask.dataframe.dispatch import (
is_categorical_dtype,
make_meta,
@@ -17,6 +11,18 @@

import cudf

from dask_cudf._expr import (
CumulativeBlockwise,
Elemwise,
Expr,
Reduction,
RenameAxis,
Var,
VarColumns,
_shuffle_module,
new_collection,
)

##
## Custom expressions
##
19 changes: 10 additions & 9 deletions python/dask_cudf/dask_cudf/_expr/groupby.py
@@ -3,22 +3,23 @@

import numpy as np
import pandas as pd
from dask_expr._collection import new_collection
from dask_expr._groupby import (
DecomposableGroupbyAggregation,
GroupBy as DXGroupBy,
GroupbyAggregation,
SeriesGroupBy as DXSeriesGroupBy,
SingleAggregation,
)
from dask_expr._util import is_scalar

from dask.dataframe.core import _concat
from dask.dataframe.groupby import Aggregation

from cudf.core.groupby.groupby import _deprecate_collect
from cudf.utils.performance_tracking import _dask_cudf_performance_tracking

from dask_cudf._expr import (
DecomposableGroupbyAggregation,
DXGroupBy,
DXSeriesGroupBy,
GroupbyAggregation,
SingleAggregation,
is_scalar,
new_collection,
)

##
## Fused groupby aggregations
##
61 changes: 15 additions & 46 deletions python/dask_cudf/dask_cudf/backends.py
@@ -543,16 +543,6 @@ def to_cudf_dispatch_from_cudf(data, **kwargs):
return data


# Define the "cudf" backend for "legacy" Dask DataFrame
class LegacyCudfBackendEntrypoint(DataFrameBackendEntrypoint):
"""Backend-entrypoint class for legacy Dask-DataFrame

This class is registered under the name "cudf" for the
``dask.dataframe.backends`` entrypoint in ``pyproject.toml``.
This "legacy" backend is only used for CSV support.
"""


# Define the "cudf" backend for expr-based Dask DataFrame
class CudfBackendEntrypoint(DataFrameBackendEntrypoint):
"""Backend-entrypoint class for Dask-Expressions
@@ -566,20 +556,19 @@ class CudfBackendEntrypoint(DataFrameBackendEntrypoint):
Examples
--------
>>> import dask
>>> import dask_expr as dx
>>> import dask.dataframe as dd
>>> with dask.config.set({"dataframe.backend": "cudf"}):
... ddf = dx.from_dict({"a": range(10)})
... ddf = dd.from_dict({"a": range(10)})
>>> type(ddf._meta)
<class 'cudf.core.dataframe.DataFrame'>
"""

@staticmethod
def to_backend(data, **kwargs):
import dask_expr as dx

from dask_cudf._expr import new_collection
from dask_cudf._expr.expr import ToCudfBackend

return dx.new_collection(ToCudfBackend(data, kwargs))
return new_collection(ToCudfBackend(data, kwargs))

@staticmethod
def from_dict(
@@ -590,10 +579,10 @@ def from_dict(
columns=None,
constructor=cudf.DataFrame,
):
import dask_expr as dx
from dask_cudf._expr import from_dict

return _default_backend(
dx.from_dict,
from_dict,
data,
npartitions=npartitions,
orient=orient,
@@ -617,35 +606,15 @@ def read_csv(
storage_options=None,
**kwargs,
):
try:
# TODO: Remove when cudf is pinned to dask>2024.12.0
import dask_expr as dx
from dask_expr.io.csv import ReadCSV
from fsspec.utils import stringify_path

if not isinstance(path, str):
path = stringify_path(path)
return dx.new_collection(
ReadCSV(
path,
dtype_backend=dtype_backend,
storage_options=storage_options,
kwargs=kwargs,
header=header,
dataframe_backend="cudf",
)
)
except ImportError:
# Requires dask>2024.12.0
from dask_cudf.io.csv import read_csv

return read_csv(
path,
*args,
header=header,
storage_options=storage_options,
**kwargs,
)
from dask_cudf.io.csv import read_csv

return read_csv(
path,
*args,
header=header,
storage_options=storage_options,
**kwargs,
)

@staticmethod
def read_json(*args, **kwargs):
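With `LegacyCudfBackendEntrypoint` deleted, `CudfBackendEntrypoint` is the single dispatch target for both collection creation and backend conversion. A small usage sketch (the CSV path is hypothetical):

```python
import dask
import dask.dataframe as dd

# Route new collections through the cudf backend for this block.
with dask.config.set({"dataframe.backend": "cudf"}):
    ddf = dd.read_csv("data/example.csv")  # dispatches to dask_cudf's read_csv

# Or convert an existing pandas-backed collection after the fact;
# this goes through CudfBackendEntrypoint.to_backend.
pdf = dd.from_dict({"a": range(10)}, npartitions=2)
gdf = pdf.to_backend("cudf")
```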
35 changes: 14 additions & 21 deletions python/dask_cudf/dask_cudf/io/parquet.py
@@ -11,32 +11,26 @@

import numpy as np
import pandas as pd
from dask_expr._expr import Elemwise
from dask_expr._util import _convert_to_list
from dask_expr.io.io import FusedIO, FusedParquetIO
from dask_expr.io.parquet import (
FragmentWrapper,
ReadParquetFSSpec,
ReadParquetPyarrowFS,
)

from dask._task_spec import Task
from dask._task_spec import List as TaskList, Task
from dask.dataframe.io.parquet.arrow import _filters_to_expression
from dask.dataframe.io.parquet.core import ParquetFunctionWrapper
from dask.tokenize import tokenize
from dask.utils import parse_bytes

try:
# TODO: Remove try/except when dask>2024.11.2
from dask._task_spec import List as TaskList
except ImportError:

def TaskList(*x):
return list(x)


import cudf

from dask_cudf._expr import (
Elemwise,
FragmentWrapper,
FusedIO,
FusedParquetIO,
ReadParquetFSSpec,
ReadParquetPyarrowFS,
_convert_to_list,
new_collection,
)

# Dask-expr imports CudfEngine from this module
from dask_cudf._legacy.io.parquet import CudfEngine
from dask_cudf.core import _deprecated_api
@@ -698,7 +692,6 @@ def read_parquet_expr(
using the ``read`` key-word argument.
"""

import dask_expr as dx
from fsspec.utils import stringify_path
from pyarrow import fs as pa_fs

@@ -785,7 +778,7 @@ def read_parquet_expr(
"parquet_file_extension is not supported when using the pyarrow filesystem."
)

return dx.new_collection(
return new_collection(
NoOp(
CudfReadParquetPyarrowFS(
path,
Expand All @@ -806,7 +799,7 @@ def read_parquet_expr(
)
)

return dx.new_collection(
return new_collection(
NoOp(
CudfReadParquetFSSpec(
path,
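The `dx.new_collection` → `new_collection` changes here are mechanical: the helper wraps a raw expression node in a user-facing collection, and it is now re-exported from `dask_cudf._expr` rather than imported from `dask_expr`. A toy illustration of the wrapping step (a sketch; assumes a working dask_cudf install):

```python
from dask_cudf._expr import from_dict, new_collection

# from_dict returns an expression-backed collection...
df = from_dict({"a": range(4)}, npartitions=2)

# ...whose underlying expression node can be rewrapped with new_collection,
# which is what read_parquet_expr does with its NoOp(...) expression tree.
df2 = new_collection(df.expr)
assert df2.compute().equals(df.compute())
```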
2 changes: 1 addition & 1 deletion python/dask_cudf/pyproject.toml
@@ -39,7 +39,7 @@ classifiers = [
]

[project.entry-points."dask.dataframe.backends"]
cudf = "dask_cudf.backends:LegacyCudfBackendEntrypoint"
cudf = "dask_cudf.backends:CudfBackendEntrypoint"

[project.entry-points."dask_expr.dataframe.backends"]
cudf = "dask_cudf.backends:CudfBackendEntrypoint"