Skip to content

Commit

Permalink
feat(python): Impl and dispatch arr.first/last to get (pola-rs#13536)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Jan 9, 2024
1 parent 85e4f53 commit 8730ced
Show file tree
Hide file tree
Showing 14 changed files with 381 additions and 2 deletions.
53 changes: 53 additions & 0 deletions crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use crate::array::{ArrayRef, FixedSizeListArray, PrimitiveArray};
use crate::legacy::compute::take::take_unchecked;
use crate::legacy::prelude::*;
use crate::legacy::utils::CustomIterTools;

/// Builds the flat take-indices that select element `index` of every sub-list.
///
/// `width` is the fixed sub-list size and `len` is the number of sub-lists.
/// Row `i` maps to `offset + i * width`, where `offset` is `index` resolved
/// through `negative_to_usize(width)`. If `index >= width`, or if
/// `negative_to_usize` returns `None`, every row maps to `None`.
fn sub_fixed_size_list_get_indexes_literal(width: usize, len: usize, index: i64) -> IdxArr {
    // `index` and `width` are identical for every row, so the in-list offset is
    // resolved once here instead of once per row inside the closure.
    // NOTE(review): the explicit `index >= width` guard is kept from the original;
    // presumably `negative_to_usize` alone does not reject every non-negative
    // out-of-range value - confirm before removing it.
    let offset = if index >= width as i64 {
        None
    } else {
        index.negative_to_usize(width)
    };

    (0..len)
        .map(|row| offset.map(|off| (off + row * width) as IdxSize))
        .collect_trusted()
}

/// Builds the flat take-indices for a per-row index array.
///
/// Entry `i` of `index` is resolved against a sub-list of size `width` and then
/// shifted by `i * width`. The result for a row is `None` when the entry is null,
/// when it is `>= width`, or when `negative_to_usize` returns `None` for it.
fn sub_fixed_size_list_get_indexes(width: usize, index: &PrimitiveArray<i64>) -> IdxArr {
    index
        .iter()
        .enumerate()
        .map(|(row, opt_idx)| {
            opt_idx
                .copied()
                // Reject positions at or past the end of the sub-list up front.
                .filter(|&idx| idx < width as i64)
                .and_then(|idx| idx.negative_to_usize(width))
                .map(|offset| (offset + row * width) as IdxSize)
        })
        .collect_trusted()
}

/// Takes the element at position `index` from every sub-list of `arr`.
///
/// The same `index` is applied to all sub-lists. Rows for which no valid
/// position can be derived are taken with a `None` index.
pub fn sub_fixed_size_list_get_literal(arr: &FixedSizeListArray, index: i64) -> ArrayRef {
    let (width, rows) = (arr.size(), arr.len());
    let indices = sub_fixed_size_list_get_indexes_literal(width, rows, index);
    // SAFETY:
    // the indices we generate are in bounds: each one is `offset + row * width`
    // with `row < arr.len()`.
    // NOTE(review): this relies on `offset < width` and on `arr.values()` holding
    // `arr.len() * arr.size()` elements - confirm against the array invariants.
    unsafe { take_unchecked(&**arr.values(), &indices) }
}

/// Takes one element from every sub-list of `arr`, using the matching entry of
/// `index` as the position inside that sub-list.
///
/// Rows for which no valid position can be derived (null entry, or position
/// outside the sub-list) are taken with a `None` index.
///
/// # Panics
///
/// Panics if `index` has more entries than `arr` has sub-lists.
pub fn sub_fixed_size_list_get(arr: &FixedSizeListArray, index: &PrimitiveArray<i64>) -> ArrayRef {
    // Entry `i` of `index` is turned into the flat position `offset + i * width`.
    // If `index` were longer than `arr`, those positions would point past the end
    // of the values buffer and the unchecked take below would read out of bounds,
    // so the precondition is enforced here instead of being assumed.
    assert!(
        index.len() <= arr.len(),
        "index array has more entries than the array has sub-lists"
    );

    let take_by = sub_fixed_size_list_get_indexes(arr.size(), index);
    let values = arr.values();
    // SAFETY:
    // the indices we generate are in bounds: each one is `offset + row * width`
    // and `row < index.len() <= arr.len()` was checked above.
    // NOTE(review): this also relies on `offset < width` and on `values` holding
    // `arr.len() * arr.size()` elements - confirm against the array invariants.
    unsafe { take_unchecked(&**values, &take_by) }
}
1 change: 1 addition & 0 deletions crates/polars-arrow/src/legacy/kernels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod agg_mean;
pub mod atan2;
pub mod concatenate;
pub mod ewm;
pub mod fixed_size_list;
pub mod float;
pub mod list;
pub mod list_bytes_iter;
Expand Down
23 changes: 22 additions & 1 deletion crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use std::error::Error;

use arrow::array::Array;
use arrow::compute::utils::combine_validities_and;
use polars_error::PolarsResult;

use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray};
use crate::prelude::{ChunkedArray, PolarsDataType};
use crate::prelude::{ChunkedArray, PolarsDataType, Series};
use crate::utils::{align_chunks_binary, align_chunks_ternary};

// We need this helper because for<'a> notation can't yet be applied properly
Expand Down Expand Up @@ -444,6 +445,26 @@ where
lhs.copy_with_chunks(chunks, keep_sorted, keep_fast_explode)
}

/// Applies `op` to every pair of chunks of `lhs` and `rhs` and collects the
/// resulting arrays into a [`Series`].
///
/// The inputs are first passed through `align_chunks_binary`, then `op` is called
/// on the downcast chunks pair by pair, in order. The output series is named
/// after `lhs`.
///
/// # Errors
///
/// Returns whatever error `Series::try_from` reports for the produced chunks.
#[inline]
pub fn binary_to_series<T, U, F>(
    lhs: &ChunkedArray<T>,
    rhs: &ChunkedArray<U>,
    mut op: F,
) -> PolarsResult<Series>
where
    T: PolarsDataType,
    U: PolarsDataType,
    F: FnMut(&T::Array, &U::Array) -> Box<dyn Array>,
{
    let (left, right) = align_chunks_binary(lhs, rhs);

    let mut out_chunks = Vec::with_capacity(left.chunks().len());
    for (left_arr, right_arr) in left.downcast_iter().zip(right.downcast_iter()) {
        out_chunks.push(op(left_arr, right_arr));
    }

    Series::try_from((left.name(), out_chunks))
}

/// Applies a kernel that produces `ArrayRef` of the same type.
///
/// # Safety
Expand Down
43 changes: 43 additions & 0 deletions crates/polars-ops/src/chunked_array/array/get.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use arrow::legacy::kernels::fixed_size_list::{
sub_fixed_size_list_get, sub_fixed_size_list_get_literal,
};
use polars_core::datatypes::ArrayChunked;
use polars_core::prelude::arity::binary_to_series;

use super::*;

/// Gets the element at position `idx` from every sub-array of `ca`.
///
/// Each chunk is processed with `sub_fixed_size_list_get_literal`; the collected
/// chunks are turned into a series named after `ca` and cast to the inner dtype
/// of `ca`.
///
/// # Errors
///
/// Returns an error if the series cannot be built from the chunks or if the
/// final cast fails.
fn array_get_literal(ca: &ArrayChunked, idx: i64) -> PolarsResult<Series> {
    let chunks = ca
        .downcast_iter()
        .map(|arr| sub_fixed_size_list_get_literal(arr, idx))
        .collect::<Vec<_>>();
    // Propagate a construction failure to the caller instead of panicking:
    // this function already returns `PolarsResult`.
    Series::try_from((ca.name(), chunks))?.cast(&ca.inner_dtype())
}

/// Get the value by index in every sub-array.
///
/// `index` must hold either a single value, which is applied to every sub-array,
/// or exactly one value per sub-array of `ca`.
/// So index `0` would return the first item of every sub-array
/// and index `-1` would return the last item of every sub-array.
/// If an index is out of bounds, it will return a `None`.
///
/// # Errors
///
/// Returns a `ComputeError` when `index` has length 1 and that single value is
/// null, or when the length of `index` is neither 1 nor `ca.len()`.
pub fn array_get(ca: &ArrayChunked, index: &Int64Chunked) -> PolarsResult<Series> {
    let index_len = index.len();

    // A single index is treated as a literal applied to all sub-arrays. This is
    // checked first, so a one-row array with a one-element index also lands here.
    if index_len == 1 {
        return match index.get(0) {
            Some(literal) => array_get_literal(ca, literal),
            None => polars_bail!(ComputeError: "unexpected null index received in `arr.get`"),
        };
    }

    if index_len != ca.len() {
        polars_bail!(
            ComputeError:
            "`arr.get` expression got an index array of length {} while the array has {} elements",
            index_len, ca.len()
        );
    }

    // One index per sub-array: process the aligned chunks pair by pair.
    let taken = binary_to_series(ca, index, |arr, idx| sub_fixed_size_list_get(arr, idx))?;
    taken.cast(&ca.inner_dtype())
}
1 change: 1 addition & 0 deletions crates/polars-ops/src/chunked_array/array/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#[cfg(feature = "array_any_all")]
mod any_all;
mod get;
mod min_max;
mod namespace;
mod sum_mean;
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-ops/src/chunked_array/array/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use super::*;
use crate::chunked_array::array::sum_mean::sum_with_nulls;
#[cfg(feature = "array_any_all")]
use crate::prelude::array::any_all::{array_all, array_any};
use crate::prelude::array::get::array_get;
use crate::prelude::array::sum_mean::sum_array_numerical;
use crate::series::ArgAgg;

Expand Down Expand Up @@ -92,6 +93,11 @@ pub trait ArrayNameSpace: AsArray {
opt_s.and_then(|s| s.as_ref().arg_max().map(|idx| idx as IdxSize))
})
}

/// Gets the value at `index` from every sub-array by delegating to the free
/// function `array_get` with the underlying array chunked array.
fn array_get(&self, index: &Int64Chunked) -> PolarsResult<Series> {
    array_get(self.as_array(), index)
}
}

impl ArrayNameSpace for ArrayChunked {}
10 changes: 10 additions & 0 deletions crates/polars-plan/src/dsl/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,14 @@ impl ArrayNameSpace {
self.0
.map_private(FunctionExpr::ArrayExpr(ArrayFunction::ArgMax))
}

/// Get items in every sub-array by index.
///
/// `index` is passed as the single extra input expression of the
/// `ArrayFunction::Get` function expression.
pub fn get(self, index: Expr) -> Expr {
    let function = FunctionExpr::ArrayExpr(ArrayFunction::Get);
    // NOTE(review): the meaning of the two `false` flags is defined by
    // `map_many_private`, which is not visible here - confirm against its signature.
    self.0.map_many_private(function, &[index], false, false)
}
}
13 changes: 12 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/array.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use polars_ops::chunked_array::array::*;

use super::*;
use crate::map;
use crate::{map, map_as_slice};

#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand All @@ -19,6 +19,7 @@ pub enum ArrayFunction {
Reverse,
ArgMin,
ArgMax,
Get,
}

impl ArrayFunction {
Expand All @@ -34,6 +35,7 @@ impl ArrayFunction {
Sort(_) => mapper.with_same_dtype(),
Reverse => mapper.with_same_dtype(),
ArgMin | ArgMax => mapper.with_dtype(IDX_DTYPE),
Get => mapper.map_to_list_and_array_inner_dtype(),
}
}
}
Expand Down Expand Up @@ -63,6 +65,7 @@ impl Display for ArrayFunction {
Reverse => "reverse",
ArgMin => "arg_min",
ArgMax => "arg_max",
Get => "get",
};
write!(f, "arr.{name}")
}
Expand All @@ -85,6 +88,7 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Reverse => map!(reverse),
ArgMin => map!(arg_min),
ArgMax => map!(arg_max),
Get => map_as_slice!(get),
}
}
}
Expand Down Expand Up @@ -141,3 +145,10 @@ pub(super) fn arg_min(s: &Series) -> PolarsResult<Series> {
pub(super) fn arg_max(s: &Series) -> PolarsResult<Series> {
Ok(s.array()?.array_arg_max().into_series())
}

/// Implements `arr.get` for a slice of input series.
///
/// `s[0]` must be the array column and `s[1]` the index column. The index is cast
/// to `Int64` before being handed to `array_get`.
///
/// # Errors
///
/// Returns an error if `s[0]` is not an array series, if the index cannot be cast
/// to `Int64`, or if `array_get` fails.
///
/// # Panics
///
/// Panics if `s` holds fewer than two series.
/// NOTE(review): presumably the `map_as_slice!` dispatch always supplies both
/// inputs - confirm.
pub(super) fn get(s: &[Series]) -> PolarsResult<Series> {
    let ca = s[0].array()?;
    let index = s[1].cast(&DataType::Int64)?;
    // Use `?` like the `array()` access above so no panic path is left in this
    // fallible function.
    let index = index.i64()?;
    ca.array_get(index)
}
3 changes: 3 additions & 0 deletions py-polars/docs/source/reference/expressions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@ The following methods are available under the `expr.arr` attribute.
Expr.arr.reverse
Expr.arr.arg_min
Expr.arr.arg_max
Expr.arr.get
Expr.arr.first
Expr.arr.last
3 changes: 3 additions & 0 deletions py-polars/docs/source/reference/series/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@ The following methods are available under the `Series.arr` attribute.
Series.arr.reverse
Series.arr.arg_min
Series.arr.arg_max
Series.arr.get
Series.arr.first
Series.arr.last
87 changes: 87 additions & 0 deletions py-polars/polars/expr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from typing import TYPE_CHECKING

from polars.utils._parse_expr_input import parse_as_expression
from polars.utils._wrap import wrap_expr

if TYPE_CHECKING:
from polars import Expr
from polars.type_aliases import IntoExprColumn


class ExprArrayNameSpace:
Expand Down Expand Up @@ -340,3 +342,88 @@ def arg_max(self) -> Expr:
"""
return wrap_expr(self._pyexpr.arr_arg_max())

def get(self, index: int | IntoExprColumn) -> Expr:
    """
    Get the value by index in the sub-arrays.

    So index `0` would return the first item of every sub-array
    and index `-1` would return the last item of every sub-array.
    If an index is out of bounds, it will return a `None`.

    Parameters
    ----------
    index
        Index to return per sub-array

    Examples
    --------
    >>> df = pl.DataFrame(
    ...     {"arr": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], "idx": [1, -2, 4]},
    ...     schema={"arr": pl.Array(pl.Int32, 3), "idx": pl.Int32},
    ... )
    >>> df.with_columns(get=pl.col("arr").arr.get("idx"))
    shape: (3, 3)
    ┌───────────────┬─────┬──────┐
    │ arr           ┆ idx ┆ get  │
    │ ---           ┆ --- ┆ ---  │
    │ array[i32, 3] ┆ i32 ┆ i32  │
    ╞═══════════════╪═════╪══════╡
    │ [1, 2, 3]     ┆ 1   ┆ 2    │
    │ [4, 5, 6]     ┆ -2  ┆ 5    │
    │ [7, 8, 9]     ┆ 4   ┆ null │
    └───────────────┴─────┴──────┘

    """
    # Keep the parsed expression under its own name rather than rebinding the
    # `index` parameter to a value of a different type.
    index_pyexpr = parse_as_expression(index)
    return wrap_expr(self._pyexpr.arr_get(index_pyexpr))

def first(self) -> Expr:
    """
    Get the first value of the sub-arrays.

    This dispatches to `get` with index `0`.

    Examples
    --------
    >>> df = pl.DataFrame(
    ...     {"a": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]},
    ...     schema={"a": pl.Array(pl.Int32, 3)},
    ... )
    >>> df.with_columns(first=pl.col("a").arr.first())
    shape: (3, 2)
    ┌───────────────┬───────┐
    │ a             ┆ first │
    │ ---           ┆ ---   │
    │ array[i32, 3] ┆ i32   │
    ╞═══════════════╪═══════╡
    │ [1, 2, 3]     ┆ 1     │
    │ [4, 5, 6]     ┆ 4     │
    │ [7, 8, 9]     ┆ 7     │
    └───────────────┴───────┘

    """
    return self.get(0)

def last(self) -> Expr:
    """
    Get the last value of the sub-arrays.

    This dispatches to `get` with index `-1`.

    Examples
    --------
    >>> df = pl.DataFrame(
    ...     {"a": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]},
    ...     schema={"a": pl.Array(pl.Int32, 3)},
    ... )
    >>> df.with_columns(last=pl.col("a").arr.last())
    shape: (3, 2)
    ┌───────────────┬──────┐
    │ a             ┆ last │
    │ ---           ┆ ---  │
    │ array[i32, 3] ┆ i32  │
    ╞═══════════════╪══════╡
    │ [1, 2, 3]     ┆ 3    │
    │ [4, 5, 6]     ┆ 6    │
    │ [7, 8, 9]     ┆ 9    │
    └───────────────┴──────┘

    """
    return self.get(-1)
Loading

0 comments on commit 8730ced

Please sign in to comment.