Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import TYPE_CHECKING, Any, ClassVar

from polars.exceptions import InvalidOperationError
from polars.polars import dtype_str_repr

import pylibcudf as plc

Expand Down Expand Up @@ -137,6 +138,7 @@ def from_polars(cls, obj: pl_expr.StringFunction) -> Self:
Name.Reverse,
Name.Tail,
Name.Titlecase,
Name.ZFill,
}
__slots__ = ("_regex_program", "name", "options")
_non_child = ("dtype", "name", "options")
Expand Down Expand Up @@ -264,6 +266,17 @@ def _validate_input(self) -> None:
raise NotImplementedError(
"strip operations only support scalar patterns"
)
elif self.name is StringFunction.Name.ZFill:
if isinstance(self.children[1], Literal):
_, width = self.children
assert isinstance(width, Literal)
if width.value is not None and width.value < 0:
dtypestr = dtype_str_repr(width.dtype.polars)
raise InvalidOperationError(
f"conversion from `{dtypestr}` to `u64` "
f"failed in column 'literal' for 1 out of "
f"1 values: [{width.value}]"
) from None

@staticmethod
def _create_regex_program(
Expand Down Expand Up @@ -322,6 +335,63 @@ def do_evaluate(
),
dtype=self.dtype,
)
elif self.name is StringFunction.Name.ZFill:
# TODO: expensive validation
# polars pads based on bytes, libcudf by visual width
# only pass chars if the visual width matches the byte length
column = self.children[0].evaluate(df, context=context)
col_len_bytes = plc.strings.attributes.count_bytes(column.obj)
col_len_chars = plc.strings.attributes.count_characters(column.obj)
equal = plc.binaryop.binary_operation(
col_len_bytes,
col_len_chars,
plc.binaryop.BinaryOperator.NULL_EQUALS,
plc.DataType(plc.TypeId.BOOL8),
)
if not plc.reduce.reduce(
equal,
plc.aggregation.all(),
plc.DataType(plc.TypeId.BOOL8),
).to_py():
raise InvalidOperationError(
"zfill only supports ascii strings with no unicode characters"
)
if isinstance(self.children[1], Literal):
width = self.children[1]
assert isinstance(width, Literal)
if width.value is None:
return Column(
plc.Column.from_scalar(
plc.Scalar.from_py(None, self.dtype.plc),
column.size,
),
self.dtype,
)
return Column(
plc.strings.padding.zfill(column.obj, width.value), self.dtype
)
else:
col_width = self.children[1].evaluate(df, context=context)
assert isinstance(col_width, Column)
all_gt_0 = plc.binaryop.binary_operation(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we do this introspection?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @mroeschke, do we usually introspect data to error at this point? If we don't, then we risk producing the wrong result in some edge cases, like a negative fill value in the zfill column overload. The scalar case is easy enough to introspect and throw, but this is a whole scan.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a case where we wouldn't be able to match the result (not necessarily an error) that Polars produces?

If so, yes, I think there's some precedent for introspecting the data to raise an error. We would want to do this in _validate_input though, so it can raise during translation. Plus we can do this introspection on the CPU by doing a similar operation on Polars objects.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The case this intends to match is the one where one of the fill elements in the column overload is negative; polars throws at runtime in this case. I suppose we want to match that behavior, but it's a performance hit to every call to the column overload :(

col_width.obj,
plc.Scalar.from_py(0, plc.DataType(plc.TypeId.INT64)),
plc.binaryop.BinaryOperator.GREATER_EQUAL,
plc.DataType(plc.TypeId.BOOL8),
)

if not plc.reduce.reduce(
all_gt_0,
plc.aggregation.all(),
plc.DataType(plc.TypeId.BOOL8),
).to_py():
raise InvalidOperationError("fill conversion failed.")

return Column(
plc.strings.padding.zfill_by_widths(column.obj, col_width.obj),
self.dtype,
)

elif self.name is StringFunction.Name.Contains:
child, arg = self.children
column = child.evaluate(df, context=context)
Expand Down
1 change: 1 addition & 0 deletions python/cudf_polars/cudf_polars/testing/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def pytest_configure(config: pytest.Config) -> None:
"tests/unit/operations/test_group_by.py::test_group_by_binary_agg_with_literal": "Incorrect broadcasting of literals in groupby-agg",
"tests/unit/operations/test_group_by.py::test_group_by_lit_series": "Incorrect broadcasting of literals in groupby-agg",
"tests/unit/operations/test_join.py::test_cross_join_slice_pushdown": "Need to implement slice pushdown for cross joins",
"tests/unit/operations/namespaces/string/test_pad.py::test_str_zfill_unicode_not_respected": "polars doesn't add zeros for unicode characters.",
"tests/unit/operations/test_rolling.py::test_rolling_group_by_empty_groups_by_take_6330": "Ordering difference, might be polars bug",
"tests/unit/sql/test_cast.py::test_cast_errors[values0-values::uint8-conversion from `f64` to `u64` failed]": "Casting that raises not supported on GPU",
"tests/unit/sql/test_cast.py::test_cast_errors[values1-values::uint4-conversion from `i64` to `u32` failed]": "Casting that raises not supported on GPU",
Expand Down
100 changes: 100 additions & 0 deletions python/cudf_polars/tests/expressions/test_stringfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,106 @@ def test_string_join(ldf, ignore_nulls, delimiter):
assert_gpu_result_equal(q)


@pytest.mark.parametrize(
"fill",
[
0,
1,
2,
5,
999,
-1,
None,
],
)
@pytest.mark.parametrize(
"input_strings",
[
["1", "0"],
["123", "45"],
["", "0"],
["abc", "def"],
],
)
def test_string_zfill(fill, input_strings):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you run these tests with polars 1.29? You might need to version-guard these tests, i.e., if POLARS_VERSION_LT_130

ldf = pl.LazyFrame({"a": input_strings})
q = ldf.select(pl.col("a").str.zfill(fill))

if fill is not None and fill < 0:
assert_collect_raises(
q,
polars_except=pl.exceptions.InvalidOperationError,
cudf_except=pl.exceptions.ComputeError,
)
else:
assert_gpu_result_equal(q)


@pytest.mark.parametrize(
"fill",
[
5
if not POLARS_VERSION_LT_130
else pytest.param(5, marks=pytest.mark.xfail(reason="fixed in Polars 1.30")),
Comment on lines +545 to +547
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
5
if not POLARS_VERSION_LT_130
else pytest.param(5, marks=pytest.mark.xfail(reason="fixed in Polars 1.30")),
pytest.param(5, marks=pytest.mark.xfail(POLARS_VERSION_LT_130, reason="fixed in Polars 1.30")),

(For a follow up PR if interested)

999
if not POLARS_VERSION_LT_130
else pytest.param(999, marks=pytest.mark.xfail(reason="fixed in Polars 1.30")),
],
)
def test_string_zfill_pl_129(fill):
ldf = pl.LazyFrame({"a": ["-1", "+2"]})
q = ldf.select(pl.col("a").str.zfill(fill))
assert_gpu_result_equal(q)


@pytest.mark.parametrize(
    "fill",
    [
        0,
        1,
        2,
        # Conditional xfail: the mark only takes effect when
        # POLARS_VERSION_LT_130 is true; otherwise the case runs as a
        # normal test. This keeps every entry a plain value or pytest.param
        # instead of switching between the two with a ternary.
        pytest.param(
            5,
            marks=pytest.mark.xfail(
                POLARS_VERSION_LT_130, reason="fixed in Polars 1.30"
            ),
        ),
        pytest.param(
            999,
            marks=pytest.mark.xfail(
                POLARS_VERSION_LT_130, reason="fixed in Polars 1.30"
            ),
        ),
        -1,
        pytest.param(None, marks=pytest.mark.xfail(reason="None dtype")),
    ],
)
def test_string_zfill_column(fill):
    """Check ``str.zfill`` when the width comes from a column.

    Every row of the ``fill`` column holds the same parametrized width.
    A negative width must make both engines raise; any other width is
    compared with ``assert_gpu_result_equal``.
    """
    input_strings = ["1", "0", "123", "45", "", "0", "-1", "+2", "abc", "def"]
    ldf = pl.DataFrame(
        {
            "input_strings": input_strings,
            # Derive the length from the strings so the two columns cannot
            # drift apart when inputs are added or removed.
            "fill": [fill] * len(input_strings),
        }
    ).lazy()
    q = ldf.select(pl.col("input_strings").str.zfill(pl.col("fill")))

    if fill is None or fill >= 0:
        assert_gpu_result_equal(q)
        return

    # Negative width: the exception type expected from the GPU engine
    # depends on the installed Polars version.
    cudf_except = (
        pl.exceptions.ComputeError
        if POLARS_VERSION_LT_130
        else pl.exceptions.InvalidOperationError
    )
    assert_collect_raises(
        q,
        polars_except=pl.exceptions.InvalidOperationError,
        cudf_except=cudf_except,
    )


def test_string_zfill_forbidden_chars():
    """Check that ``str.zfill`` on non-ASCII strings raises on the GPU engine.

    The input mixes multi-byte strings, an ASCII string and a null. The
    exception type expected from the GPU engine depends on the installed
    Polars version.
    """
    if POLARS_VERSION_LT_130:
        gpu_error = pl.exceptions.ComputeError
    else:
        gpu_error = pl.exceptions.InvalidOperationError

    frame = pl.LazyFrame({"a": ["Café", "345", "東京", None]})
    query = frame.select(pl.col("a").str.zfill(3))

    # NOTE(review): the empty tuple presumably tells the helper that the CPU
    # engine is expected not to raise -- confirm against assert_collect_raises.
    assert_collect_raises(query, polars_except=(), cudf_except=gpu_error)


@pytest.mark.parametrize(
"width",
[
Expand Down