Skip to content

Commit

Permalink
feat: Enable is_first/last_distinct for not nested non-numeric list (
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Apr 10, 2024
1 parent c758416 commit dc415b0
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 11 deletions.
14 changes: 13 additions & 1 deletion crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,20 @@ impl DataType {
matches!(self, DataType::List(_))
}

/// Check if this [`DataType`] is a array
pub fn is_array(&self) -> bool {
#[cfg(feature = "dtype-array")]
{
matches!(self, DataType::Array(_, _))
}
#[cfg(not(feature = "dtype-array"))]
{
false
}
}

pub fn is_nested(&self) -> bool {
self.is_list() || self.is_struct()
self.is_list() || self.is_struct() || self.is_array()
}

/// Check if this [`DataType`] is a struct
Expand Down
6 changes: 5 additions & 1 deletion crates/polars-ops/src/series/ops/is_first_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,11 @@ pub fn is_first_distinct(s: &Series) -> PolarsResult<BooleanChunked> {
},
#[cfg(feature = "dtype-struct")]
Struct(_) => return is_first_distinct_struct(&s),
List(inner) if inner.is_numeric() => {
List(inner) => {
polars_ensure!(
!inner.is_nested(),
InvalidOperation: "`is_first_distinct` on list type is only allowed if the inner type is not nested."
);
let ca = s.list().unwrap();
return is_first_distinct_list(ca);
},
Expand Down
6 changes: 5 additions & 1 deletion crates/polars-ops/src/series/ops/is_last_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ pub fn is_last_distinct(s: &Series) -> PolarsResult<BooleanChunked> {
},
#[cfg(feature = "dtype-struct")]
Struct(_) => return is_last_distinct_struct(&s),
List(inner) if inner.is_numeric() => {
List(inner) => {
polars_ensure!(
!inner.is_nested(),
InvalidOperation: "`is_last_distinct` on list type is only allowed if the inner type is not nested."
);
let ca = s.list().unwrap();
return is_last_distinct_list(ca);
},
Expand Down
51 changes: 43 additions & 8 deletions py-polars/tests/unit/operations/test_is_first_last_distinct.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from __future__ import annotations

import datetime
from typing import Any

import pytest

import polars as pl
Expand Down Expand Up @@ -39,13 +44,47 @@ def test_is_first_distinct_struct() -> None:
assert_frame_equal(result, expected)


def test_is_first_distinct_list() -> None:
lf = pl.LazyFrame({"a": [[1, 2], [3], [1, 2], [4, 5], [4, 5]]})
result = lf.select(pl.col("a").is_first_distinct())
expected = pl.LazyFrame({"a": [True, True, False, True, False]})
@pytest.mark.parametrize(
"data",
[
[[1, 2], [3], [1, 2], [4, None], [4, None], [], []],
[[True, None], [True], [True, None], [False], [False], [], []],
[[b"1", b"2"], [b"3"], [b"1", b"2"], [b"4", None], [b"4", None], [], []],
[["a", "b"], ["&"], ["a", "b"], ["...", None], ["...", None], [], []],
[
[datetime.date(2000, 10, 1), datetime.date(2001, 1, 30)],
[datetime.date(1949, 10, 1)],
[datetime.date(2000, 10, 1), datetime.date(2001, 1, 30)],
[datetime.date(1998, 7, 1), None],
[datetime.date(1998, 7, 1), None],
[],
[],
],
],
)
def test_is_first_last_distinct_list(data: list[list[Any] | None]) -> None:
lf = pl.LazyFrame({"a": data})
result = lf.select(
first=pl.col("a").is_first_distinct(), last=pl.col("a").is_last_distinct()
)
expected = pl.LazyFrame(
{
"first": [True, True, False, True, False, True, False],
"last": [False, True, True, False, True, False, True],
}
)
assert_frame_equal(result, expected)


def test_is_first_last_distinct_list_inner_nested() -> None:
df = pl.DataFrame({"a": [[[1, 2]], [[1, 2]]]})
err_msg = "only allowed if the inner type is not nested"
with pytest.raises(pl.InvalidOperationError, match=err_msg):
df.select(pl.col("a").is_first_distinct())
with pytest.raises(pl.InvalidOperationError, match=err_msg):
df.select(pl.col("a").is_last_distinct())


def test_is_first_distinct_various() -> None:
# numeric
s = pl.Series([1, 1, None, 2, None, 3, 3])
Expand Down Expand Up @@ -106,10 +145,6 @@ def test_is_last_distinct() -> None:
)
expected = [False, True, False, True, True, False, True]
assert s.is_last_distinct().to_list() == expected
# list
s = pl.Series([[1, 2], [1, 2], None, [2, 3], None, [3, 4], [3, 4]])
expected = [False, True, False, True, True, False, True]
assert s.is_last_distinct().to_list() == expected


@pytest.mark.parametrize("dtypes", [pl.Int32, pl.String, pl.Boolean, pl.List(pl.Int32)])
Expand Down

0 comments on commit dc415b0

Please sign in to comment.