Skip to content

Commit

Permalink
Remove Scalar container type from polars interpreter (#15953)
Browse files Browse the repository at this point in the history
Now we always return columns and, where usage of a scalar might be
correct (for example broadcasting in binops), we check if the column
is "actually" a scalar and extract it.

This is slightly annoying because we have to introspect things in
various places. But without changing libcudf to treat length-1 columns
as always broadcastable like scalars this is, I think, the best we can
do.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - https://github.com/brandon-b-miller
  - James Lamb (https://github.com/jameslamb)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #15953
  • Loading branch information
wence- authored Jun 11, 2024
1 parent 66c2f4f commit 22ac996
Show file tree
Hide file tree
Showing 12 changed files with 249 additions and 97 deletions.
8 changes: 7 additions & 1 deletion python/cudf_polars/cudf_polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@

from __future__ import annotations

from cudf_polars._version import __git_commit__, __version__
from cudf_polars.callback import execute_with_cudf
from cudf_polars.dsl.translate import translate_ir

__all__: list[str] = ["execute_with_cudf", "translate_ir"]
__all__: list[str] = [
"execute_with_cudf",
"translate_ir",
"__git_commit__",
"__version__",
]
3 changes: 1 addition & 2 deletions python/cudf_polars/cudf_polars/containers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

from __future__ import annotations

__all__: list[str] = ["DataFrame", "Column", "NamedColumn", "Scalar"]
__all__: list[str] = ["DataFrame", "Column", "NamedColumn"]

from cudf_polars.containers.column import Column, NamedColumn
from cudf_polars.containers.dataframe import DataFrame
from cudf_polars.containers.scalar import Scalar
28 changes: 27 additions & 1 deletion python/cudf_polars/cudf_polars/containers/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@


class Column:
"""A column with sortedness metadata."""
"""An immutable column with sortedness metadata."""

obj: plc.Column
is_sorted: plc.types.Sorted
order: plc.types.Order
null_order: plc.types.NullOrder
is_scalar: bool

def __init__(
self,
Expand All @@ -33,10 +34,33 @@ def __init__(
null_order: plc.types.NullOrder = plc.types.NullOrder.BEFORE,
):
self.obj = column
self.is_scalar = self.obj.size() == 1
if self.obj.size() <= 1:
is_sorted = plc.types.Sorted.YES
self.is_sorted = is_sorted
self.order = order
self.null_order = null_order

@functools.cached_property
def obj_scalar(self) -> plc.Scalar:
"""
A copy of the column object as a pylibcudf Scalar.
Returns
-------
pylibcudf Scalar object.
Raises
------
ValueError
If the column is not length-1.
"""
if not self.is_scalar:
raise ValueError(
f"Cannot convert a column of length {self.obj.size()} to scalar"
)
return plc.copying.get_element(self.obj, 0)

def sorted_like(self, like: Column, /) -> Self:
"""
Copy sortedness properties from a column onto self.
Expand Down Expand Up @@ -81,6 +105,8 @@ def set_sorted(
-------
Self with metadata set.
"""
if self.obj.size() <= 1:
is_sorted = plc.types.Sorted.YES
self.is_sorted = is_sorted
self.order = order
self.null_order = null_order
Expand Down
6 changes: 2 additions & 4 deletions python/cudf_polars/cudf_polars/containers/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class DataFrame:
"""A representation of a dataframe."""

columns: list[NamedColumn]
table: plc.Table | None
table: plc.Table

def __init__(self, columns: Sequence[NamedColumn]) -> None:
self.columns = list(columns)
Expand All @@ -41,7 +41,7 @@ def __init__(self, columns: Sequence[NamedColumn]) -> None:

def copy(self) -> Self:
"""Return a shallow copy of self."""
return type(self)(self.columns)
return type(self)([c.copy() for c in self.columns])

def to_polars(self) -> pl.DataFrame:
"""Convert to a polars DataFrame."""
Expand Down Expand Up @@ -70,8 +70,6 @@ def num_columns(self) -> int:
@cached_property
def num_rows(self) -> int:
"""Number of rows."""
if self.table is None:
raise ValueError("Number of rows of frame with scalars makes no sense")
return self.table.num_rows()

@classmethod
Expand Down
23 changes: 0 additions & 23 deletions python/cudf_polars/cudf_polars/containers/scalar.py

This file was deleted.

114 changes: 69 additions & 45 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
DSL nodes for the polars expression language.
An expression node is a function, `DataFrame -> Column` or `DataFrame -> Scalar`.
An expression node is a function, `DataFrame -> Column`.
The evaluation context is provided by a LogicalPlan node, and can
affect the evaluation rule as well as providing the dataframe input.
Expand All @@ -26,7 +26,7 @@

import cudf._lib.pylibcudf as plc

from cudf_polars.containers import Column, NamedColumn, Scalar
from cudf_polars.containers import Column, NamedColumn
from cudf_polars.utils import sorting

if TYPE_CHECKING:
Expand Down Expand Up @@ -165,7 +165,7 @@ def do_evaluate(
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
) -> Column: # TODO: return type is a lie for Literal
) -> Column:
"""
Evaluate this expression given a dataframe for context.
Expand All @@ -187,8 +187,7 @@ def do_evaluate(
Returns
-------
Column representing the evaluation of the expression (or maybe
a scalar).
Column representing the evaluation of the expression.
Raises
------
Expand All @@ -205,7 +204,7 @@ def evaluate(
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
) -> Column: # TODO: return type is a lie for Literal
) -> Column:
"""
Evaluate this expression given a dataframe for context.
Expand All @@ -222,23 +221,13 @@ def evaluate(
Notes
-----
Individual subclasses should implement :meth:`do_allocate`,
Individual subclasses should implement :meth:`do_evaluate`,
this method provides logic to handle lookups in the
substitution mapping.
The typed return value of :class:`Column` is not true when
evaluating :class:`Literal` nodes (which instead produce
:class:`Scalar` objects). However, these duck-type to having a
pylibcudf container object inside them, and usually they end
up appearing in binary expressions which pylibcudf handles
appropriately since there are overloads for (column, scalar)
pairs. We don't have to handle (scalar, scalar) in binops
since the polars optimizer has a constant-folding pass.
Returns
-------
Column representing the evaluation of the expression (or maybe
a scalar).
Column representing the evaluation of the expression.
Raises
------
Expand Down Expand Up @@ -319,24 +308,35 @@ def evaluate(
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
) -> NamedColumn:
"""Evaluate this expression given a dataframe for context."""
"""
Evaluate this expression given a dataframe for context.
Parameters
----------
df
DataFrame providing context
context
Execution context
mapping
Substitution mapping
Returns
-------
NamedColumn attaching a name to an evaluated Column
See Also
--------
:meth:`Expr.evaluate` for details, this function just adds the
name to a column produced from an expression.
"""
obj = self.value.evaluate(df, context=context, mapping=mapping)
if isinstance(obj, Scalar):
return NamedColumn(
plc.Column.from_scalar(obj.obj, 1),
self.name,
is_sorted=plc.types.Sorted.YES,
order=plc.types.Order.ASCENDING,
null_order=plc.types.NullOrder.BEFORE,
)
else:
return NamedColumn(
obj.obj,
self.name,
is_sorted=obj.is_sorted,
order=obj.order,
null_order=obj.null_order,
)
return NamedColumn(
obj.obj,
self.name,
is_sorted=obj.is_sorted,
order=obj.order,
null_order=obj.null_order,
)

def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
Expand All @@ -363,7 +363,7 @@ def do_evaluate(
) -> Column:
"""Evaluate this expression given a dataframe for context."""
# datatype of pyarrow scalar is correct by construction.
return Scalar(plc.interop.from_arrow(self.value)) # type: ignore
return Column(plc.Column.from_scalar(plc.interop.from_arrow(self.value), 1))


class Col(Expr):
Expand Down Expand Up @@ -402,8 +402,14 @@ def do_evaluate(
mapping: Mapping[Expr, Column] | None = None,
) -> Column:
"""Evaluate this expression given a dataframe for context."""
# TODO: type is wrong, and dtype
return df.num_rows # type: ignore
return Column(
plc.Column.from_scalar(
plc.interop.from_arrow(
pa.scalar(df.num_rows, type=plc.interop.to_arrow(self.dtype))
),
1,
)
)

def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
Expand Down Expand Up @@ -664,10 +670,24 @@ def do_evaluate(
return Column(plc.strings.case.to_upper(column.obj))
elif self.name == pl_expr.StringFunction.EndsWith:
column, suffix = columns
return Column(plc.strings.find.ends_with(column.obj, suffix.obj))
return Column(
plc.strings.find.ends_with(
column.obj,
suffix.obj_scalar
if column.obj.size() != suffix.obj.size() and suffix.is_scalar
else suffix.obj,
)
)
elif self.name == pl_expr.StringFunction.StartsWith:
column, suffix = columns
return Column(plc.strings.find.starts_with(column.obj, suffix.obj))
column, prefix = columns
return Column(
plc.strings.find.starts_with(
column.obj,
prefix.obj_scalar
if column.obj.size() != prefix.obj.size() and prefix.is_scalar
else prefix.obj,
)
)
else:
raise NotImplementedError(f"StringFunction {self.name}")

Expand Down Expand Up @@ -875,9 +895,6 @@ def __init__(
self, dtype: plc.DataType, name: str, options: Any, value: Expr
) -> None:
super().__init__(dtype)
# TODO: fix polars name
if name == "nunique":
name = "n_unique"
self.name = name
self.options = options
self.children = (value,)
Expand Down Expand Up @@ -1092,8 +1109,15 @@ def do_evaluate(
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
)
lop = left.obj
rop = right.obj
if left.obj.size() != right.obj.size():
if left.is_scalar:
lop = left.obj_scalar
elif right.is_scalar:
rop = right.obj_scalar
return Column(
plc.binaryop.binary_operation(left.obj, right.obj, self.op, self.dtype),
plc.binaryop.binary_operation(lop, rop, self.op, self.dtype),
)

def collect_agg(self, *, depth: int) -> AggInfo:
Expand Down
Loading

0 comments on commit 22ac996

Please sign in to comment.