From 8730ced3d450cbb340b0e4be1d38cba2e0f891f5 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Tue, 9 Jan 2024 23:02:51 +0800 Subject: [PATCH] feat(python): Impl and dispatch arr.first/last to get (#13536) --- .../src/legacy/kernels/fixed_size_list.rs | 53 +++++++++++ crates/polars-arrow/src/legacy/kernels/mod.rs | 1 + .../src/chunked_array/ops/arity.rs | 23 ++++- .../polars-ops/src/chunked_array/array/get.rs | 43 +++++++++ .../polars-ops/src/chunked_array/array/mod.rs | 1 + .../src/chunked_array/array/namespace.rs | 6 ++ crates/polars-plan/src/dsl/array.rs | 10 +++ .../src/dsl/function_expr/array.rs | 13 ++- .../source/reference/expressions/array.rst | 3 + .../docs/source/reference/series/array.rst | 3 + py-polars/polars/expr/array.py | 87 +++++++++++++++++++ py-polars/polars/series/array.py | 75 ++++++++++++++++ py-polars/src/expr/array.rs | 4 + .../tests/unit/namespaces/array/test_array.py | 61 +++++++++++++ 14 files changed, 381 insertions(+), 2 deletions(-) create mode 100644 crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs create mode 100644 crates/polars-ops/src/chunked_array/array/get.rs diff --git a/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs b/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs new file mode 100644 index 000000000000..1ace23b09d8a --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs @@ -0,0 +1,53 @@ +use crate::array::{ArrayRef, FixedSizeListArray, PrimitiveArray}; +use crate::legacy::compute::take::take_unchecked; +use crate::legacy::prelude::*; +use crate::legacy::utils::CustomIterTools; + +fn sub_fixed_size_list_get_indexes_literal(width: usize, len: usize, index: i64) -> IdxArr { + (0..len) + .map(|i| { + if index >= width as i64 { + return None; + } + + index + .negative_to_usize(width) + .map(|idx| (idx + i * width) as IdxSize) + }) + .collect_trusted() +} + +fn sub_fixed_size_list_get_indexes(width: usize, index: &PrimitiveArray) -> IdxArr { + index + .iter() + .enumerate() + 
.map(|(i, idx)| { + if let Some(idx) = idx { + if *idx >= width as i64 { + return None; + } + + idx.negative_to_usize(width) + .map(|idx| (idx + i * width) as IdxSize) + } else { + None + } + }) + .collect_trusted() +} + +pub fn sub_fixed_size_list_get_literal(arr: &FixedSizeListArray, index: i64) -> ArrayRef { + let take_by = sub_fixed_size_list_get_indexes_literal(arr.size(), arr.len(), index); + let values = arr.values(); + // Safety: + // the indices we generate are in bounds + unsafe { take_unchecked(&**values, &take_by) } +} + +pub fn sub_fixed_size_list_get(arr: &FixedSizeListArray, index: &PrimitiveArray) -> ArrayRef { + let take_by = sub_fixed_size_list_get_indexes(arr.size(), index); + let values = arr.values(); + // Safety: + // the indices we generate are in bounds + unsafe { take_unchecked(&**values, &take_by) } +} diff --git a/crates/polars-arrow/src/legacy/kernels/mod.rs b/crates/polars-arrow/src/legacy/kernels/mod.rs index c6a634ef8c22..2c93ea0eca9d 100644 --- a/crates/polars-arrow/src/legacy/kernels/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/mod.rs @@ -7,6 +7,7 @@ pub mod agg_mean; pub mod atan2; pub mod concatenate; pub mod ewm; +pub mod fixed_size_list; pub mod float; pub mod list; pub mod list_bytes_iter; diff --git a/crates/polars-core/src/chunked_array/ops/arity.rs b/crates/polars-core/src/chunked_array/ops/arity.rs index 884cf4237d8b..46a70587d298 100644 --- a/crates/polars-core/src/chunked_array/ops/arity.rs +++ b/crates/polars-core/src/chunked_array/ops/arity.rs @@ -2,9 +2,10 @@ use std::error::Error; use arrow::array::Array; use arrow::compute::utils::combine_validities_and; +use polars_error::PolarsResult; use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray}; -use crate::prelude::{ChunkedArray, PolarsDataType}; +use crate::prelude::{ChunkedArray, PolarsDataType, Series}; use crate::utils::{align_chunks_binary, align_chunks_ternary}; // We need this helper because for<'a> notation can't yet be applied properly 
@@ -444,6 +445,26 @@ where lhs.copy_with_chunks(chunks, keep_sorted, keep_fast_explode) } +#[inline] +pub fn binary_to_series( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, +) -> PolarsResult +where + T: PolarsDataType, + U: PolarsDataType, + F: FnMut(&T::Array, &U::Array) -> Box, +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr)) + .collect::>(); + Series::try_from((lhs.name(), chunks)) +} + /// Applies a kernel that produces `ArrayRef` of the same type. /// /// # Safety diff --git a/crates/polars-ops/src/chunked_array/array/get.rs b/crates/polars-ops/src/chunked_array/array/get.rs new file mode 100644 index 000000000000..6cb5630676e9 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/get.rs @@ -0,0 +1,43 @@ +use arrow::legacy::kernels::fixed_size_list::{ + sub_fixed_size_list_get, sub_fixed_size_list_get_literal, +}; +use polars_core::datatypes::ArrayChunked; +use polars_core::prelude::arity::binary_to_series; + +use super::*; + +fn array_get_literal(ca: &ArrayChunked, idx: i64) -> PolarsResult { + let chunks = ca + .downcast_iter() + .map(|arr| sub_fixed_size_list_get_literal(arr, idx)) + .collect::>(); + Series::try_from((ca.name(), chunks)) + .unwrap() + .cast(&ca.inner_dtype()) +} + +/// Get the value by literal index in the array. +/// So index `0` would return the first item of every sub-array +/// and index `-1` would return the last item of every sub-array +/// if an index is out of bounds, it will return a `None`. 
+pub fn array_get(ca: &ArrayChunked, index: &Int64Chunked) -> PolarsResult { + match index.len() { + 1 => { + let index = index.get(0); + if let Some(index) = index { + array_get_literal(ca, index) + } else { + polars_bail!(ComputeError: "unexpected null index received in `arr.get`") + } + }, + len if len == ca.len() => { + let out = binary_to_series(ca, index, |arr, idx| sub_fixed_size_list_get(arr, idx)); + out?.cast(&ca.inner_dtype()) + }, + len => polars_bail!( + ComputeError: + "`arr.get` expression got an index array of length {} while the array has {} elements", + len, ca.len() + ), + } +} diff --git a/crates/polars-ops/src/chunked_array/array/mod.rs b/crates/polars-ops/src/chunked_array/array/mod.rs index 1f54a6592b83..46c4d80d792a 100644 --- a/crates/polars-ops/src/chunked_array/array/mod.rs +++ b/crates/polars-ops/src/chunked_array/array/mod.rs @@ -1,5 +1,6 @@ #[cfg(feature = "array_any_all")] mod any_all; +mod get; mod min_max; mod namespace; mod sum_mean; diff --git a/crates/polars-ops/src/chunked_array/array/namespace.rs b/crates/polars-ops/src/chunked_array/array/namespace.rs index cca6ca86df09..0b7f80192c03 100644 --- a/crates/polars-ops/src/chunked_array/array/namespace.rs +++ b/crates/polars-ops/src/chunked_array/array/namespace.rs @@ -3,6 +3,7 @@ use super::*; use crate::chunked_array::array::sum_mean::sum_with_nulls; #[cfg(feature = "array_any_all")] use crate::prelude::array::any_all::{array_all, array_any}; +use crate::prelude::array::get::array_get; use crate::prelude::array::sum_mean::sum_array_numerical; use crate::series::ArgAgg; @@ -92,6 +93,11 @@ pub trait ArrayNameSpace: AsArray { opt_s.and_then(|s| s.as_ref().arg_max().map(|idx| idx as IdxSize)) }) } + + fn array_get(&self, index: &Int64Chunked) -> PolarsResult { + let ca = self.as_array(); + array_get(ca, index) + } } impl ArrayNameSpace for ArrayChunked {} diff --git a/crates/polars-plan/src/dsl/array.rs b/crates/polars-plan/src/dsl/array.rs index 7c267a9bdb04..4e74300c4461 100644 --- 
a/crates/polars-plan/src/dsl/array.rs +++ b/crates/polars-plan/src/dsl/array.rs @@ -79,4 +79,14 @@ impl ArrayNameSpace { self.0 .map_private(FunctionExpr::ArrayExpr(ArrayFunction::ArgMax)) } + + /// Get items in every sub-array by index. + pub fn get(self, index: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::ArrayExpr(ArrayFunction::Get), + &[index], + false, + false, + ) + } } diff --git a/crates/polars-plan/src/dsl/function_expr/array.rs b/crates/polars-plan/src/dsl/function_expr/array.rs index a82182cce994..0de1415f36ba 100644 --- a/crates/polars-plan/src/dsl/function_expr/array.rs +++ b/crates/polars-plan/src/dsl/function_expr/array.rs @@ -1,7 +1,7 @@ use polars_ops::chunked_array::array::*; use super::*; -use crate::map; +use crate::{map, map_as_slice}; #[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -19,6 +19,7 @@ pub enum ArrayFunction { Reverse, ArgMin, ArgMax, + Get, } impl ArrayFunction { @@ -34,6 +35,7 @@ impl ArrayFunction { Sort(_) => mapper.with_same_dtype(), Reverse => mapper.with_same_dtype(), ArgMin | ArgMax => mapper.with_dtype(IDX_DTYPE), + Get => mapper.map_to_list_and_array_inner_dtype(), } } } @@ -63,6 +65,7 @@ impl Display for ArrayFunction { Reverse => "reverse", ArgMin => "arg_min", ArgMax => "arg_max", + Get => "get", }; write!(f, "arr.{name}") } @@ -85,6 +88,7 @@ impl From for SpecialEq> { Reverse => map!(reverse), ArgMin => map!(arg_min), ArgMax => map!(arg_max), + Get => map_as_slice!(get), } } } @@ -141,3 +145,10 @@ pub(super) fn arg_min(s: &Series) -> PolarsResult { pub(super) fn arg_max(s: &Series) -> PolarsResult { Ok(s.array()?.array_arg_max().into_series()) } + +pub(super) fn get(s: &[Series]) -> PolarsResult { + let ca = s[0].array()?; + let index = s[1].cast(&DataType::Int64)?; + let index = index.i64().unwrap(); + ca.array_get(index) +} diff --git a/py-polars/docs/source/reference/expressions/array.rst 
b/py-polars/docs/source/reference/expressions/array.rst index 401ae0c862e4..36481077ac4f 100644 --- a/py-polars/docs/source/reference/expressions/array.rst +++ b/py-polars/docs/source/reference/expressions/array.rst @@ -20,3 +20,6 @@ The following methods are available under the `expr.arr` attribute. Expr.arr.reverse Expr.arr.arg_min Expr.arr.arg_max + Expr.arr.get + Expr.arr.first + Expr.arr.last diff --git a/py-polars/docs/source/reference/series/array.rst b/py-polars/docs/source/reference/series/array.rst index 901ac1d4adfa..55af8b11783c 100644 --- a/py-polars/docs/source/reference/series/array.rst +++ b/py-polars/docs/source/reference/series/array.rst @@ -20,3 +20,6 @@ The following methods are available under the `Series.arr` attribute. Series.arr.reverse Series.arr.arg_min Series.arr.arg_max + Series.arr.get + Series.arr.first + Series.arr.last diff --git a/py-polars/polars/expr/array.py b/py-polars/polars/expr/array.py index 2693b2299e23..812876c1dd58 100644 --- a/py-polars/polars/expr/array.py +++ b/py-polars/polars/expr/array.py @@ -2,10 +2,12 @@ from typing import TYPE_CHECKING +from polars.utils._parse_expr_input import parse_as_expression from polars.utils._wrap import wrap_expr if TYPE_CHECKING: from polars import Expr + from polars.type_aliases import IntoExprColumn class ExprArrayNameSpace: @@ -340,3 +342,88 @@ def arg_max(self) -> Expr: """ return wrap_expr(self._pyexpr.arr_arg_max()) + + def get(self, index: int | IntoExprColumn) -> Expr: + """ + Get the value by index in the sub-arrays. + + So index `0` would return the first item of every sublist + and index `-1` would return the last item of every sublist + if an index is out of bounds, it will return a `None`. + + Parameters + ---------- + index + Index to return per sub-array + + Examples + -------- + >>> df = pl.DataFrame( + ... {"arr": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], "idx": [1, -2, 4]}, + ... schema={"arr": pl.Array(pl.Int32, 3), "idx": pl.Int32}, + ... 
) + >>> df.with_columns(get=pl.col("arr").arr.get("idx")) + shape: (3, 3) + ┌───────────────┬─────┬──────┐ + │ arr ┆ idx ┆ get │ + │ --- ┆ --- ┆ --- │ + │ array[i32, 3] ┆ i32 ┆ i32 │ + ╞═══════════════╪═════╪══════╡ + │ [1, 2, 3] ┆ 1 ┆ 2 │ + │ [4, 5, 6] ┆ -2 ┆ 5 │ + │ [7, 8, 9] ┆ 4 ┆ null │ + └───────────────┴─────┴──────┘ + + """ + index = parse_as_expression(index) + return wrap_expr(self._pyexpr.arr_get(index)) + + def first(self) -> Expr: + """ + Get the first value of the sub-arrays. + + Examples + -------- + >>> df = pl.DataFrame( + ... {"a": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]}, + ... schema={"a": pl.Array(pl.Int32, 3)}, + ... ) + >>> df.with_columns(first=pl.col("a").arr.first()) + shape: (3, 2) + ┌───────────────┬───────┐ + │ a ┆ first │ + │ --- ┆ --- │ + │ array[i32, 3] ┆ i32 │ + ╞═══════════════╪═══════╡ + │ [1, 2, 3] ┆ 1 │ + │ [4, 5, 6] ┆ 4 │ + │ [7, 8, 9] ┆ 7 │ + └───────────────┴───────┘ + + """ + return self.get(0) + + def last(self) -> Expr: + """ + Get the last value of the sub-arrays. + + Examples + -------- + >>> df = pl.DataFrame( + ... {"a": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]}, + ... schema={"a": pl.Array(pl.Int32, 3)}, + ... ) + >>> df.with_columns(last=pl.col("a").arr.last()) + shape: (3, 2) + ┌───────────────┬──────┐ + │ a ┆ last │ + │ --- ┆ --- │ + │ array[i32, 3] ┆ i32 │ + ╞═══════════════╪══════╡ + │ [1, 2, 3] ┆ 3 │ + │ [4, 5, 6] ┆ 6 │ + │ [7, 8, 9] ┆ 9 │ + └───────────────┴──────┘ + + """ + return self.get(-1) diff --git a/py-polars/polars/series/array.py b/py-polars/polars/series/array.py index 3ff18ef5c8d4..2033ff59148c 100644 --- a/py-polars/polars/series/array.py +++ b/py-polars/polars/series/array.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from polars import Series from polars.polars import PySeries + from polars.type_aliases import IntoExprColumn @expr_dispatch @@ -266,3 +267,77 @@ def arg_max(self) -> Series: ] """ + + def get(self, index: int | IntoExprColumn) -> Series: + """ + Get the value by index in the sub-arrays. 
+ + So index `0` would return the first item of every sublist + and index `-1` would return the last item of every sublist + if an index is out of bounds, it will return a `None`. + + Parameters + ---------- + index + Index to return per sublist + + Returns + ------- + Series + Series of inner data type. + + Examples + -------- + >>> s = pl.Series( + ... "a", [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=pl.Array(pl.Int32, 3) + ... ) + >>> s.arr.get(pl.Series([1, -2, 4])) + shape: (3,) + Series: 'a' [i32] + [ + 2 + 5 + null + ] + + """ + + def first(self) -> Series: + """ + Get the first value of the sub-arrays. + + Examples + -------- + >>> s = pl.Series( + ... "a", [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=pl.Array(pl.Int32, 3) + ... ) + >>> s.arr.first() + shape: (3,) + Series: 'a' [i32] + [ + 1 + 4 + 7 + ] + + """ + + def last(self) -> Series: + """ + Get the last value of the sub-arrays. + + Examples + -------- + >>> s = pl.Series( + ... "a", [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=pl.Array(pl.Int32, 3) + ... 
) + >>> s.arr.last() + shape: (3,) + Series: 'a' [i32] + [ + 3 + 6 + 9 + ] + + """ diff --git a/py-polars/src/expr/array.rs b/py-polars/src/expr/array.rs index 497271d1a813..2d9193366da8 100644 --- a/py-polars/src/expr/array.rs +++ b/py-polars/src/expr/array.rs @@ -59,4 +59,8 @@ impl PyExpr { fn arr_arg_max(&self) -> Self { self.inner.clone().arr().arg_max().into() } + + fn arr_get(&self, index: PyExpr) -> Self { + self.inner.clone().arr().get(index.inner).into() + } } diff --git a/py-polars/tests/unit/namespaces/array/test_array.py b/py-polars/tests/unit/namespaces/array/test_array.py index 5803ff8aca12..bd36e1301776 100644 --- a/py-polars/tests/unit/namespaces/array/test_array.py +++ b/py-polars/tests/unit/namespaces/array/test_array.py @@ -1,3 +1,5 @@ +import datetime + import numpy as np import pytest @@ -102,3 +104,62 @@ def test_array_arg_min_max() -> None: assert_series_equal(s.arr.arg_min(), expected) expected = pl.Series("a", [2, 0], dtype=pl.UInt32) assert_series_equal(s.arr.arg_max(), expected) + + +def test_array_get() -> None: + # test index literal + s = pl.Series( + "a", + [[1, 2, 3, 4], [5, 6, None, None], [7, 8, 9, 10]], + dtype=pl.Array(pl.Int64, 4), + ) + out = s.arr.get(1) + expected = pl.Series("a", [2, 6, 8], dtype=pl.Int64) + assert_series_equal(out, expected) + + # test index expr + out = s.arr.get(pl.Series([1, -2, 4])) + expected = pl.Series("a", [2, None, None], dtype=pl.Int64) + assert_series_equal(out, expected) + + # test logical type + s = pl.Series( + "a", + [ + [datetime.date(1999, 1, 1), datetime.date(2000, 1, 1)], + [datetime.date(2001, 10, 1), None], + [None, None], + ], + dtype=pl.Array(pl.Date, 2), + ) + out = s.arr.get(pl.Series([1, -2, 4])) + expected = pl.Series( + "a", + [datetime.date(2000, 1, 1), datetime.date(2001, 10, 1), None], + dtype=pl.Date, + ) + assert_series_equal(out, expected) + + +def test_arr_first_last() -> None: + s = pl.Series( + "a", + [[1, 2, 3], [None, 5, 6], [None, None, None]], + 
dtype=pl.Array(pl.Int64, 3), + ) + + first = s.arr.first() + expected_first = pl.Series( + "a", + [1, None, None], + dtype=pl.Int64, + ) + assert_series_equal(first, expected_first) + + last = s.arr.last() + expected_last = pl.Series( + "a", + [3, 6, None], + dtype=pl.Int64, + ) + assert_series_equal(last, expected_last)