From 2a42096b6a2596675955a9216a97508dd8332b41 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Mon, 8 Apr 2024 17:09:45 +0100 Subject: [PATCH] feat: add business_day_count function (#15512) --- crates/polars-lazy/Cargo.toml | 1 + crates/polars-ops/Cargo.toml | 1 + crates/polars-ops/src/series/ops/business.rs | 97 +++++++++++++++++++ crates/polars-ops/src/series/ops/mod.rs | 4 + crates/polars-plan/Cargo.toml | 2 + .../src/dsl/function_expr/business.rs | 45 +++++++++ .../polars-plan/src/dsl/function_expr/mod.rs | 12 +++ .../src/dsl/function_expr/schema.rs | 2 + .../polars-plan/src/dsl/functions/business.rs | 15 +++ crates/polars-plan/src/dsl/functions/mod.rs | 4 + crates/polars/Cargo.toml | 1 + py-polars/Cargo.toml | 1 + .../reference/expressions/functions.rst | 1 + py-polars/polars/__init__.py | 2 + py-polars/polars/functions/__init__.py | 2 + py-polars/polars/functions/business.py | 68 +++++++++++++ py-polars/src/functions/business.rs | 11 +++ py-polars/src/functions/mod.rs | 2 + py-polars/src/lib.rs | 4 + .../time_series/test_business_day_count.py | 30 ++++++ .../business/test_business_day_count.py | 65 +++++++++++++ 21 files changed, 370 insertions(+) create mode 100644 crates/polars-ops/src/series/ops/business.rs create mode 100644 crates/polars-plan/src/dsl/function_expr/business.rs create mode 100644 crates/polars-plan/src/dsl/functions/business.rs create mode 100644 py-polars/polars/functions/business.py create mode 100644 py-polars/src/functions/business.rs create mode 100644 py-polars/tests/parametric/time_series/test_business_day_count.py create mode 100644 py-polars/tests/unit/functions/business/test_business_day_count.py diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index e71846c6ea83..4aca12a5eeb0 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -96,6 +96,7 @@ is_between = ["polars-plan/is_between"] is_unique = ["polars-plan/is_unique"] cross_join = ["polars-plan/cross_join", "polars-pipe?/cross_join", "polars-ops/cross_join"] asof_join = ["polars-plan/asof_join", "polars-time", "polars-ops/asof_join"] +business = ["polars-plan/business"] concat_str = ["polars-plan/concat_str"] range = ["polars-plan/range"] mode = ["polars-plan/mode"] diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 2caf329d430a..19eb92154fe3 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -75,6 +75,7 @@ is_unique = [] unique_counts = [] is_between = [] approx_unique = [] +business = ["dtype-date"] fused = [] cutqcut = ["dtype-categorical", "dtype-struct"] rle = ["dtype-struct"] diff --git a/crates/polars-ops/src/series/ops/business.rs b/crates/polars-ops/src/series/ops/business.rs new file mode 100644 index 000000000000..5b792453b5c4 --- /dev/null +++ b/crates/polars-ops/src/series/ops/business.rs @@ -0,0 +1,97 @@ +use polars_core::prelude::arity::binary_elementwise_values; +use polars_core::prelude::*; + +/// Count the number of business days between `start` and `end`, excluding `end`. +pub fn business_day_count(start: &Series, end: &Series) -> PolarsResult { + let start_dates = start.date()?; + let end_dates = end.date()?; + + // TODO: support customising weekdays + let week_mask: [bool; 7] = [true, true, true, true, true, false, false]; + let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32; + + let out = match (start_dates.len(), end_dates.len()) { + (_, 1) => { + if let Some(end_date) = end_dates.get(0) { + start_dates.apply_values(|start_date| { + business_day_count_impl( + start_date, + end_date, + &week_mask, + n_business_days_in_week_mask, + ) + }) + } else { + Int32Chunked::full_null(start_dates.name(), start_dates.len()) + } + }, + (1, _) => { + if let Some(start_date) = start_dates.get(0) { + end_dates.apply_values(|end_date| { + business_day_count_impl( + start_date, + end_date, + &week_mask, + n_business_days_in_week_mask, + ) + }) + } else { + Int32Chunked::full_null(start_dates.name(), end_dates.len()) + } + }, + _ => binary_elementwise_values(start_dates, end_dates, |start_date, end_date| { + business_day_count_impl( + start_date, + end_date, + &week_mask, + n_business_days_in_week_mask, + ) + }), + }; + Ok(out.into_series()) +} + +/// Ported from: +/// https://github.com/numpy/numpy/blob/e59c074842e3f73483afa5ddef031e856b9fd313/numpy/_core/src/multiarray/datetime_busday.c#L355-L433 +fn business_day_count_impl( + mut start_date: i32, + mut end_date: i32, + week_mask: &[bool; 7], + n_business_days_in_week_mask: i32, +) -> i32 { + let swapped = start_date > end_date; + if swapped { + (start_date, end_date) = (end_date, start_date); + start_date += 1; + end_date += 1; + } + + let mut start_weekday = weekday(start_date); + let diff = end_date - start_date; + let whole_weeks = diff / 7; + let mut count = 0; + count += whole_weeks * n_business_days_in_week_mask; + start_date += whole_weeks * 7; + while start_date < end_date { + if unsafe { *week_mask.get_unchecked(start_weekday) } { + count += 1; + } + start_date += 1; + start_weekday += 1; + if start_weekday >= 7 { + start_weekday = 0; + } + } + if swapped { + -count + } else { + count + } +} + +fn weekday(x: i32) -> usize { + // the first modulo might return a negative number, so we add 7 and take + // the modulo again so we're sure we have something between 0 (Monday) + // and 6 (Sunday) + (((x - 4) % 7 + 7) % 7) as usize +} diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index 9670a296e95b..0ebcff5daace 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -5,6 +5,8 @@ mod approx_algo; #[cfg(feature = "approx_unique")] mod approx_unique; mod arg_min_max; +#[cfg(feature = "business")] +mod business; mod clip; #[cfg(feature = "cum_agg")] mod cum_agg; @@ -65,6 +67,8 @@ pub use approx_algo::*; #[cfg(feature = "approx_unique")] pub use approx_unique::*; pub use arg_min_max::ArgAgg; +#[cfg(feature = "business")] +pub use business::*; pub use clip::*; #[cfg(feature = "cum_agg")] pub use cum_agg::*; diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 7a09c8c53173..01ddc96a6f33 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -110,6 +110,7 @@ is_between = ["polars-ops/is_between"] cross_join = ["polars-ops/cross_join"] asof_join = ["polars-time", "polars-ops/asof_join"] concat_str = [] +business = ["polars-ops/business"] range = [] mode = ["polars-ops/mode"] cum_agg = ["polars-ops/cum_agg"] @@ -252,6 +253,7 @@ features = [ "ciborium", "dtype-decimal", "arg_where", + "business", "range", "meta", "hive_partitions", diff --git a/crates/polars-plan/src/dsl/function_expr/business.rs b/crates/polars-plan/src/dsl/function_expr/business.rs new file mode 100644 index 000000000000..f9a38b1165cc --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/business.rs @@ -0,0 +1,45 @@ +use std::fmt::{Display, Formatter}; + +use polars_core::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::dsl::SpecialEq; +use crate::map_as_slice; +use crate::prelude::SeriesUdf; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, PartialEq, Debug, Eq, Hash)] +pub enum BusinessFunction { + #[cfg(feature = "business")] + BusinessDayCount, +} + +impl Display for BusinessFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use BusinessFunction::*; + let s = match self { + #[cfg(feature = "business")] + &BusinessDayCount => "business_day_count", + }; + write!(f, "{s}") + } +} +impl From for SpecialEq> { + fn from(func: BusinessFunction) -> Self { + use BusinessFunction::*; + match func { + #[cfg(feature = "business")] + BusinessDayCount => { + map_as_slice!(business_day_count) + }, + } + } +} + +#[cfg(feature = "business")] +pub(super) fn business_day_count(s: &[Series]) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + polars_ops::prelude::business_day_count(start, end) +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index ae44125be688..82d04e7da55f 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -7,6 +7,8 @@ mod array; mod binary; mod boolean; mod bounds; +#[cfg(feature = "business")] +mod business; #[cfg(feature = "dtype-categorical")] mod cat; #[cfg(feature = "round_series")] @@ -81,6 +83,8 @@ use serde::{Deserialize, Serialize}; pub(crate) use self::binary::BinaryFunction; pub use self::boolean::BooleanFunction; +#[cfg(feature = "business")] +pub(super) use self::business::BusinessFunction; #[cfg(feature = "dtype-categorical")] pub(crate) use self::cat::CategoricalFunction; #[cfg(feature = "temporal")] @@ -117,6 +121,8 @@ pub enum FunctionExpr { // Other expressions Boolean(BooleanFunction), + #[cfg(feature = "business")] + Business(BusinessFunction), #[cfg(feature = "abs")] Abs, Negate, @@ -349,6 +355,8 @@ impl Hash for FunctionExpr { // Other expressions Boolean(f) => f.hash(state), + #[cfg(feature = "business")] + Business(f) => f.hash(state), Pow(f) => f.hash(state), #[cfg(feature = "search_sorted")] SearchSorted(f) => f.hash(state), @@ -557,6 +565,8 @@ impl Display for FunctionExpr { // Other expressions Boolean(func) => return write!(f, "{func}"), + #[cfg(feature = "business")] + Business(func) => return write!(f, "{func}"), #[cfg(feature = "abs")] Abs => "abs", Negate => "negate", @@ -815,6 +825,8 @@ impl From for SpecialEq> { // Other expressions Boolean(func) => func.into(), + #[cfg(feature = "business")] + Business(func) => func.into(), #[cfg(feature = "abs")] Abs => map!(abs::abs), Negate => map!(dispatch::negate), diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 187c4d2783cb..3082527e829d 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -27,6 +27,8 @@ impl FunctionExpr { // Other expressions Boolean(func) => func.get_field(mapper), + #[cfg(feature = "business")] + Business(_) => mapper.with_dtype(DataType::Int32), #[cfg(feature = "abs")] Abs => mapper.with_same_dtype(), Negate => mapper.with_same_dtype(), diff --git a/crates/polars-plan/src/dsl/functions/business.rs b/crates/polars-plan/src/dsl/functions/business.rs new file mode 100644 index 000000000000..4bfdcc0b20cc --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/business.rs @@ -0,0 +1,15 @@ +use super::*; + +#[cfg(feature = "dtype-date")] +pub fn business_day_count(start: Expr, end: Expr) -> Expr { + let input = vec![start, end]; + + Expr::Function { + input, + function: FunctionExpr::Business(BusinessFunction::BusinessDayCount {}), + options: FunctionOptions { + allow_rename: true, + ..Default::default() + }, + } +} diff --git a/crates/polars-plan/src/dsl/functions/mod.rs b/crates/polars-plan/src/dsl/functions/mod.rs index 9eca113f33d0..8b8fe24c7163 100644 --- a/crates/polars-plan/src/dsl/functions/mod.rs +++ b/crates/polars-plan/src/dsl/functions/mod.rs @@ -2,6 +2,8 @@ //! //! Functions on expressions that might be useful. mod arity; +#[cfg(feature = "business")] +mod business; #[cfg(feature = "dtype-struct")] mod coerce; mod concat; @@ -18,6 +20,8 @@ mod syntactic_sugar; mod temporal; pub use arity::*; +#[cfg(all(feature = "business", feature = "dtype-date"))] +pub use business::*; #[cfg(feature = "dtype-struct")] pub use coerce::*; pub use concat::*; diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 444d7b7cc947..027213677801 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -129,6 +129,7 @@ array_any_all = ["polars-lazy?/array_any_all", "dtype-array"] asof_join = ["polars-lazy?/asof_join", "polars-ops/asof_join"] bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx"] binary_encoding = ["polars-ops/binary_encoding", "polars-lazy?/binary_encoding", "polars-sql?/binary_encoding"] +business = ["polars-lazy?/business", "polars-ops/business"] checked_arithmetic = ["polars-core/checked_arithmetic"] chunked_ids = ["polars-ops?/chunked_ids"] coalesce = ["polars-lazy?/coalesce"] diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 9fec55932484..e8f80d83dc06 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -38,6 +38,7 @@ features = [ "abs", "approx_unique", "arg_where", + "business", "concat_str", "cum_agg", "cumulative_eval", diff --git a/py-polars/docs/source/reference/expressions/functions.rst b/py-polars/docs/source/reference/expressions/functions.rst index d240454e136b..546c44360332 100644 --- a/py-polars/docs/source/reference/expressions/functions.rst +++ b/py-polars/docs/source/reference/expressions/functions.rst @@ -23,6 +23,7 @@ These functions are available from the Polars module root and can be used as exp arctan2d arg_sort_by arg_where + business_day_count coalesce concat_list concat_str diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 492c3e437f63..830dc506c0ff 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -106,6 +106,7 @@ arctan2d, arg_sort_by, arg_where, + business_day_count, coalesce, col, collect_all, @@ -330,6 +331,7 @@ # polars.functions "align_frames", "arg_where", + "business_day_count", "concat", "date_range", "date_ranges", diff --git a/py-polars/polars/functions/__init__.py b/py-polars/polars/functions/__init__.py index 935cdfe159a6..048587300c21 100644 --- a/py-polars/polars/functions/__init__.py +++ b/py-polars/polars/functions/__init__.py @@ -25,6 +25,7 @@ from polars.functions.as_datatype import date_ as date from polars.functions.as_datatype import datetime_ as datetime from polars.functions.as_datatype import time_ as time +from polars.functions.business import business_day_count from polars.functions.col import col from polars.functions.eager import align_frames, concat from polars.functions.lazy import ( @@ -124,6 +125,7 @@ "arctan2", "arctan2d", "arg_sort_by", + "business_day_count", "coalesce", "col", "collect_all", diff --git a/py-polars/polars/functions/business.py b/py-polars/polars/functions/business.py new file mode 100644 index 000000000000..ae5791fde2a6 --- /dev/null +++ b/py-polars/polars/functions/business.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import contextlib +from typing import TYPE_CHECKING + +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.wrap import wrap_expr + +with contextlib.suppress(ImportError): # Module not available when building docs + import polars.polars as plr + +if TYPE_CHECKING: + from datetime import date + + from polars import Expr + from polars.type_aliases import IntoExprColumn + + +def business_day_count( + start: date | IntoExprColumn, + end: date | IntoExprColumn, +) -> Expr: + """ + Count the number of business days between `start` and `end` (not including `end`). + + By default, Saturday and Sunday are excluded. The ability to + customise week mask and holidays is not yet implemented. + + Parameters + ---------- + start + Start dates. + end + End dates. + + Returns + ------- + Expr + + Examples + -------- + >>> from datetime import date + >>> df = pl.DataFrame( + ... { + ... "start": [date(2020, 1, 1), date(2020, 1, 2)], + ... "end": [date(2020, 1, 2), date(2020, 1, 10)], + ... } + ... ) + >>> df.with_columns( + ... total_day_count=(pl.col("end") - pl.col("start")).dt.total_days(), + ... business_day_count=pl.business_day_count("start", "end"), + ... ) + shape: (2, 4) + ┌────────────┬────────────┬─────────────────┬────────────────────┐ + │ start ┆ end ┆ total_day_count ┆ business_day_count │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ date ┆ date ┆ i64 ┆ i32 │ + ╞════════════╪════════════╪═════════════════╪════════════════════╡ + │ 2020-01-01 ┆ 2020-01-02 ┆ 1 ┆ 1 │ + │ 2020-01-02 ┆ 2020-01-10 ┆ 8 ┆ 6 │ + └────────────┴────────────┴─────────────────┴────────────────────┘ + + Note how the two "count" columns differ due to the weekend (2020-01-04 - 2020-01-05) + not being counted by `business_day_count`. + """ + start_pyexpr = parse_as_expression(start) + end_pyexpr = parse_as_expression(end) + return wrap_expr(plr.business_day_count(start_pyexpr, end_pyexpr)) diff --git a/py-polars/src/functions/business.rs b/py-polars/src/functions/business.rs new file mode 100644 index 000000000000..246f902b895a --- /dev/null +++ b/py-polars/src/functions/business.rs @@ -0,0 +1,11 @@ +use polars::lazy::dsl; +use pyo3::prelude::*; + +use crate::PyExpr; + +#[pyfunction] +pub fn business_day_count(start: PyExpr, end: PyExpr) -> PyExpr { + let start = start.inner; + let end = end.inner; + dsl::business_day_count(start, end).into() +} diff --git a/py-polars/src/functions/mod.rs b/py-polars/src/functions/mod.rs index 56f6377a3669..0bb5e55ea23c 100644 --- a/py-polars/src/functions/mod.rs +++ b/py-polars/src/functions/mod.rs @@ -1,4 +1,5 @@ mod aggregation; +mod business; mod eager; mod io; mod lazy; @@ -10,6 +11,7 @@ mod string_cache; mod whenthen; pub use aggregation::*; +pub use business::*; pub use eager::*; pub use io::*; pub use lazy::*; diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index 787e747a8b56..00193e5158c1 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -131,6 +131,10 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(functions::time_ranges)) .unwrap(); + // Functions - business + m.add_wrapped(wrap_pyfunction!(functions::business_day_count)) + .unwrap(); + // Functions - aggregation m.add_wrapped(wrap_pyfunction!(functions::all_horizontal)) .unwrap(); diff --git a/py-polars/tests/parametric/time_series/test_business_day_count.py b/py-polars/tests/parametric/time_series/test_business_day_count.py new file mode 100644 index 000000000000..0cb1bf95df33 --- /dev/null +++ b/py-polars/tests/parametric/time_series/test_business_day_count.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import datetime as dt + +import hypothesis.strategies as st +import numpy as np +from hypothesis import given, reject + +import polars as pl +from polars._utils.various import parse_version + + +@given( + start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + end=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), +) +def test_against_np_busday_count( + start: dt.date, + end: dt.date, +) -> None: + result = ( + pl.DataFrame({"start": [start], "end": [end]}) + .select(n=pl.business_day_count("start", "end"))["n"] + .item() + ) + expected = np.busday_count(start, end) + if start > end and parse_version(np.__version__) < parse_version("1.25"): + # Bug in old versions of numpy + reject() + assert result == expected diff --git a/py-polars/tests/unit/functions/business/test_business_day_count.py b/py-polars/tests/unit/functions/business/test_business_day_count.py new file mode 100644 index 000000000000..74befbd3268b --- /dev/null +++ b/py-polars/tests/unit/functions/business/test_business_day_count.py @@ -0,0 +1,65 @@ +from datetime import date + +import polars as pl +from polars.testing import assert_series_equal + + +def test_business_day_count() -> None: + # (Expression, expression) + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "end": [date(2020, 1, 2), date(2020, 1, 10)], + } + ) + result = df.select( + business_day_count=pl.business_day_count("start", "end"), + )["business_day_count"] + expected = pl.Series("business_day_count", [1, 6], pl.Int32) + assert_series_equal(result, expected) + + # (Expression, scalar) + result = df.select( + business_day_count=pl.business_day_count("start", date(2020, 1, 10)), + )["business_day_count"] + expected = pl.Series("business_day_count", [7, 6], pl.Int32) + assert_series_equal(result, expected) + result = df.select( + business_day_count=pl.business_day_count("start", pl.lit(None, dtype=pl.Date)), + )["business_day_count"] + expected = pl.Series("business_day_count", [None, None], pl.Int32) + assert_series_equal(result, expected) + + # (Scalar, expression) + result = df.select( + business_day_count=pl.business_day_count(date(2020, 1, 1), "end"), + )["business_day_count"] + expected = pl.Series("business_day_count", [1, 7], pl.Int32) + assert_series_equal(result, expected) + result = df.select( + business_day_count=pl.business_day_count(pl.lit(None, dtype=pl.Date), "end"), + )["business_day_count"] + expected = pl.Series("business_day_count", [None, None], pl.Int32) + assert_series_equal(result, expected) + + # (Scalar, scalar) + result = df.select( + business_day_count=pl.business_day_count(date(2020, 1, 1), date(2020, 1, 10)), + )["business_day_count"] + expected = pl.Series("business_day_count", [7], pl.Int32) + assert_series_equal(result, expected) + + +def test_business_day_count_schema() -> None: + lf = pl.LazyFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "end": [date(2020, 1, 2), date(2020, 1, 10)], + } + ) + result = lf.select( + business_day_count=pl.business_day_count("start", "end"), + ) + assert result.schema["business_day_count"] == pl.Int32 + assert result.collect().schema["business_day_count"] == pl.Int32 + assert 'col("start").business_day_count([col("end")])' in result.explain()