Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Scalar container type from polars interpreter #15953

Merged
merged 14 commits into from
Jun 11, 2024
Previous commit
Next commit
Simplify logic for broadcasting
  • Loading branch information
wence- committed Jun 10, 2024
commit 57097194da64111c17f067b4a2fa4197749b7ded
24 changes: 9 additions & 15 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
 from cudf_polars.utils import sorting
 
 if TYPE_CHECKING:
-    from collections.abc import MutableMapping
+    from collections.abc import MutableMapping, Set
     from typing import Literal
 
     from cudf_polars.typing import Schema
Expand Down Expand Up @@ -96,31 +96,25 @@ def broadcast(
     ``target_length`` is provided and not all columns are length-1
     (i.e. ``n != 1``), then ``target_length`` must be equal to ``n``.
     """
-    lengths = {column.obj.size() for column in columns}
-    if len(lengths - {1}) > 1:
-        raise RuntimeError("Mismatching column lengths")
+    lengths: Set[int] = {column.obj.size() for column in columns}
vyasr marked this conversation as resolved.
Show resolved Hide resolved
     if lengths == {1}:
         if target_length is None:
             return list(columns)
         nrows = target_length
-    elif len(lengths) == 1:
-        if target_length is not None and target_length not in lengths:
-            raise RuntimeError(
-                "Cannot broadcast columns of length "
-                f"{lengths.pop()} to {target_length=}"
-            )
-        return list(columns)
     else:
-        (nrows,) = lengths - {1}
-        if target_length is not None and target_length != nrows:
+        try:
+            (nrows,) = lengths - {1}
wence- marked this conversation as resolved.
Show resolved Hide resolved
+        except ValueError as e:
+            raise RuntimeError("Mismatching column lengths") from e
+        if target_length is not None and nrows != target_length:
             raise RuntimeError(
-                f"Cannot broadcast columns of length {nrows} to {target_length=}"
+                f"Cannot broadcast columns of length {nrows=} to {target_length=}"
             )
wence- marked this conversation as resolved.
Show resolved Hide resolved
     return [
         column
         if column.obj.size() != 1
         else NamedColumn(
-            plc.Column.from_scalar(plc.copying.get_element(column.obj, 0), nrows),
+            plc.Column.from_scalar(column.obj_scalar, nrows),
             column.name,
             is_sorted=plc.types.Sorted.YES,
             order=plc.types.Order.ASCENDING,
Expand Down