Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support weekend argument in business_day_count #15544

Merged
merged 5 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions crates/polars-ops/src/series/ops/business.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,21 @@ use polars_core::prelude::arity::binary_elementwise_values;
use polars_core::prelude::*;

/// Count the number of business days between `start` and `end`, excluding `end`.
pub fn business_day_count(start: &Series, end: &Series) -> PolarsResult<Series> {
///
/// # Arguments
/// - `start`: Series holding start dates.
/// - `end`: Series holding end dates.
/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day.
pub fn business_day_count(
start: &Series,
end: &Series,
week_mask: [bool; 7],
) -> PolarsResult<Series> {
if !week_mask.iter().any(|&x| x) {
polars_bail!(ComputeError:"`week_mask` must have at least one business day");
}
let start_dates = start.date()?;
let end_dates = end.date()?;

// TODO: support customising weekdays
let week_mask: [bool; 7] = [true, true, true, true, true, false, false];
let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;

let out = match (start_dates.len(), end_dates.len()) {
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-plan/src/dsl/function_expr/business.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ use crate::prelude::SeriesUdf;
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum BusinessFunction {
#[cfg(feature = "business")]
BusinessDayCount,
BusinessDayCount { week_mask: [bool; 7] },
}

impl Display for BusinessFunction {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use BusinessFunction::*;
let s = match self {
#[cfg(feature = "business")]
&BusinessDayCount => "business_day_count",
&BusinessDayCount { .. } => "business_day_count",
};
write!(f, "{s}")
}
Expand All @@ -30,16 +30,16 @@ impl From<BusinessFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
use BusinessFunction::*;
match func {
#[cfg(feature = "business")]
BusinessDayCount => {
map_as_slice!(business_day_count)
BusinessDayCount { week_mask } => {
map_as_slice!(business_day_count, week_mask)
},
}
}
}

#[cfg(feature = "business")]
pub(super) fn business_day_count(s: &[Series]) -> PolarsResult<Series> {
pub(super) fn business_day_count(s: &[Series], week_mask: [bool; 7]) -> PolarsResult<Series> {
let start = &s[0];
let end = &s[1];
polars_ops::prelude::business_day_count(start, end)
polars_ops::prelude::business_day_count(start, end, week_mask)
}
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/functions/business.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use super::*;

#[cfg(feature = "dtype-date")]
pub fn business_day_count(start: Expr, end: Expr) -> Expr {
pub fn business_day_count(start: Expr, end: Expr, week_mask: [bool; 7]) -> Expr {
let input = vec![start, end];

Expr::Function {
input,
function: FunctionExpr::Business(BusinessFunction::BusinessDayCount {}),
function: FunctionExpr::Business(BusinessFunction::BusinessDayCount { week_mask }),
options: FunctionOptions {
allow_rename: true,
..Default::default()
Expand Down
29 changes: 24 additions & 5 deletions py-polars/polars/functions/business.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Iterable

from polars._utils.parse_expr_input import parse_as_expression
from polars._utils.wrap import wrap_expr
Expand All @@ -19,19 +19,21 @@
def business_day_count(
start: date | IntoExprColumn,
end: date | IntoExprColumn,
week_mask: Iterable[bool] = (True, True, True, True, True, False, False),
) -> Expr:
"""
Count the number of business days between `start` and `end` (not including `end`).

By default, Saturday and Sunday are excluded. The ability to
customise week mask and holidays is not yet implemented.

Parameters
----------
start
Start dates.
end
End dates.
week_mask
Which days of the week to count. The default is Monday to Friday.
If you wanted to count only Monday to Thursday, you would pass
`(True, True, True, True, False, False, False)`.

Returns
-------
Expand Down Expand Up @@ -62,7 +64,24 @@ def business_day_count(

Note how the two "count" columns differ due to the weekend (2020-01-04 - 2020-01-05)
not being counted by `business_day_count`.

You can pass a custom weekend - for example, if you only take Sunday off:

>>> week_mask = (True, True, True, True, True, True, False)
>>> df.with_columns(
... total_day_count=(pl.col("end") - pl.col("start")).dt.total_days(),
... business_day_count=pl.business_day_count("start", "end", week_mask),
... )
shape: (2, 4)
┌────────────┬────────────┬─────────────────┬────────────────────┐
│ start ┆ end ┆ total_day_count ┆ business_day_count │
│ --- ┆ --- ┆ --- ┆ --- │
│ date ┆ date ┆ i64 ┆ i32 │
╞════════════╪════════════╪═════════════════╪════════════════════╡
│ 2020-01-01 ┆ 2020-01-02 ┆ 1 ┆ 1 │
│ 2020-01-02 ┆ 2020-01-10 ┆ 8 ┆ 7 │
└────────────┴────────────┴─────────────────┴────────────────────┘
"""
start_pyexpr = parse_as_expression(start)
end_pyexpr = parse_as_expression(end)
return wrap_expr(plr.business_day_count(start_pyexpr, end_pyexpr))
return wrap_expr(plr.business_day_count(start_pyexpr, end_pyexpr, week_mask))
4 changes: 2 additions & 2 deletions py-polars/src/functions/business.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use pyo3::prelude::*;
use crate::PyExpr;

#[pyfunction]
pub fn business_day_count(start: PyExpr, end: PyExpr) -> PyExpr {
pub fn business_day_count(start: PyExpr, end: PyExpr, week_mask: [bool; 7]) -> PyExpr {
let start = start.inner;
let end = end.inner;
dsl::business_day_count(start, end).into()
dsl::business_day_count(start, end, week_mask).into()
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import hypothesis.strategies as st
import numpy as np
from hypothesis import given, reject
from hypothesis import assume, given, reject

import polars as pl
from polars._utils.various import parse_version
Expand All @@ -13,17 +13,24 @@
@given(
start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)),
end=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)),
week_mask=st.lists(
st.sampled_from([True, False]),
min_size=7,
max_size=7,
),
)
def test_against_np_busday_count(
start: dt.date,
end: dt.date,
week_mask: tuple[bool, ...],
) -> None:
assume(any(week_mask))
result = (
pl.DataFrame({"start": [start], "end": [end]})
.select(n=pl.business_day_count("start", "end"))["n"]
.select(n=pl.business_day_count("start", "end", week_mask=week_mask))["n"]
.item()
)
expected = np.busday_count(start, end)
expected = np.busday_count(start, end, weekmask=week_mask)
if start > end and parse_version(np.__version__) < parse_version("1.25"):
# Bug in old versions of numpy
reject()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import date

import pytest

import polars as pl
from polars.testing import assert_series_equal

Expand Down Expand Up @@ -50,6 +52,45 @@ def test_business_day_count() -> None:
assert_series_equal(result, expected)


def test_business_day_count_w_week_mask() -> None:
df = pl.DataFrame(
{
"start": [date(2020, 1, 1), date(2020, 1, 2)],
"end": [date(2020, 1, 2), date(2020, 1, 10)],
}
)
result = df.select(
business_day_count=pl.business_day_count(
"start", "end", week_mask=(True, True, True, True, True, True, False)
),
)["business_day_count"]
expected = pl.Series("business_day_count", [1, 7], pl.Int32)
assert_series_equal(result, expected)

result = df.select(
business_day_count=pl.business_day_count(
"start", "end", week_mask=(True, True, True, False, False, False, True)
),
)["business_day_count"]
expected = pl.Series("business_day_count", [1, 4], pl.Int32)
assert_series_equal(result, expected)


def test_business_day_count_w_week_mask_invalid() -> None:
with pytest.raises(ValueError, match=r"expected a sequence of length 7 \(got 2\)"):
pl.business_day_count("start", "end", week_mask=(False, 0)) # type: ignore[arg-type]
df = pl.DataFrame(
{
"start": [date(2020, 1, 1), date(2020, 1, 2)],
"end": [date(2020, 1, 2), date(2020, 1, 10)],
}
)
with pytest.raises(
pl.ComputeError, match="`week_mask` must have at least one business day"
):
df.select(pl.business_day_count("start", "end", week_mask=[False] * 7))


def test_business_day_count_schema() -> None:
lf = pl.LazyFrame(
{
Expand Down
Loading