Skip to content

Commit

Permalink
feat: support weekend argument in business_day_count
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Apr 8, 2024
1 parent 5d82f0d commit 7fff1e8
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 22 deletions.
14 changes: 10 additions & 4 deletions crates/polars-ops/src/series/ops/business.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@ use polars_core::prelude::arity::binary_elementwise_values;
use polars_core::prelude::*;

/// Count the number of business days between `start` and `end`, excluding `end`.
pub fn business_day_count(start: &Series, end: &Series) -> PolarsResult<Series> {
///
/// # Arguments
/// - `start`: Series holding start dates.
/// - `end`: Series holding end dates.
/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day.
pub fn business_day_count(
start: &Series,
end: &Series,
week_mask: [bool; 7],
) -> PolarsResult<Series> {
let start_dates = start.date()?;
let end_dates = end.date()?;

// TODO: support customising weekdays
let week_mask: [bool; 7] = [true, true, true, true, true, false, false];
let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;

let out = match (start_dates.len(), end_dates.len()) {
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-plan/src/dsl/function_expr/business.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ use crate::prelude::SeriesUdf;
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum BusinessFunction {
#[cfg(feature = "business")]
BusinessDayCount,
BusinessDayCount { week_mask: [bool; 7] },
}

impl Display for BusinessFunction {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use BusinessFunction::*;
let s = match self {
#[cfg(feature = "business")]
&BusinessDayCount => "business_day_count",
&BusinessDayCount { .. } => "business_day_count",
};
write!(f, "{s}")
}
Expand All @@ -30,16 +30,16 @@ impl From<BusinessFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
use BusinessFunction::*;
match func {
#[cfg(feature = "business")]
BusinessDayCount => {
map_as_slice!(business_day_count)
BusinessDayCount { week_mask } => {
map_as_slice!(business_day_count, week_mask)
},
}
}
}

#[cfg(feature = "business")]
pub(super) fn business_day_count(s: &[Series]) -> PolarsResult<Series> {
pub(super) fn business_day_count(s: &[Series], week_mask: [bool; 7]) -> PolarsResult<Series> {
let start = &s[0];
let end = &s[1];
polars_ops::prelude::business_day_count(start, end)
polars_ops::prelude::business_day_count(start, end, week_mask)
}
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/functions/business.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use super::*;

#[cfg(feature = "dtype-date")]
pub fn business_day_count(start: Expr, end: Expr) -> Expr {
pub fn business_day_count(start: Expr, end: Expr, week_mask: [bool; 7]) -> Expr {
let input = vec![start, end];

Expr::Function {
input,
function: FunctionExpr::Business(BusinessFunction::BusinessDayCount {}),
function: FunctionExpr::Business(BusinessFunction::BusinessDayCount { week_mask }),
options: FunctionOptions {
allow_rename: true,
..Default::default()
Expand Down
65 changes: 59 additions & 6 deletions py-polars/polars/functions/business.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Iterable

from polars._utils.parse_expr_input import parse_as_expression
from polars._utils.wrap import wrap_expr
Expand All @@ -13,25 +13,61 @@
from datetime import date

from polars import Expr
from polars.type_aliases import IntoExprColumn
from polars.type_aliases import DayOfWeek, IntoExprColumn

DAY_NAMES = (
"Mon",
"Tue",
"Wed",
"Thu",
"Fri",
"Sat",
"Sun",
)


def _make_week_mask(
weekend: Iterable[str] | None,
) -> tuple[bool, ...]:
if weekend is None:
return tuple([True] * 7)
if isinstance(weekend, str):
weekend_set = {weekend}
else:
weekend_set = set(weekend)
for day in weekend_set:
if day not in DAY_NAMES:
msg = f"Expected one of {DAY_NAMES}, got: {day}"
raise ValueError(msg)
return tuple(
[
False if v in weekend else True # noqa: SIM211
for v in DAY_NAMES
]
)


def business_day_count(
start: date | IntoExprColumn,
end: date | IntoExprColumn,
weekend: DayOfWeek | Iterable[DayOfWeek] | None = ("Sat", "Sun"),
) -> Expr:
"""
Count the number of business days between `start` and `end` (not including `end`).
By default, Saturday and Sunday are excluded. Use `weekend` to exclude different
days. The ability to customise holidays is not yet implemented.
Parameters
----------
start
Start dates.
end
End dates.
weekend
Which days of the week to exclude. The default is `('Sat', 'Sun')`, but you
can also pass, for example, `weekend=('Fri', 'Sat')`, `weekend='Sun'`,
or `weekend=None`.
Allowed values in the tuple are 'Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat',
and 'Sun'.
Returns
-------
Expand Down Expand Up @@ -62,7 +98,24 @@ def business_day_count(
Note how the two "count" columns differ due to the weekend (2020-01-04 - 2020-01-05)
not being counted by `business_day_count`.
You can pass a custom weekend - for example, if you only take Sunday off:
>>> df.with_columns(
... total_day_count=(pl.col("end") - pl.col("start")).dt.total_days(),
... business_day_count=pl.business_day_count("start", "end", weekend="Sun"),
... )
shape: (2, 4)
┌────────────┬────────────┬─────────────────┬────────────────────┐
│ start ┆ end ┆ total_day_count ┆ business_day_count │
│ --- ┆ --- ┆ --- ┆ --- │
│ date ┆ date ┆ i64 ┆ i32 │
╞════════════╪════════════╪═════════════════╪════════════════════╡
│ 2020-01-01 ┆ 2020-01-02 ┆ 1 ┆ 1 │
│ 2020-01-02 ┆ 2020-01-10 ┆ 8 ┆ 7 │
└────────────┴────────────┴─────────────────┴────────────────────┘
"""
start_pyexpr = parse_as_expression(start)
end_pyexpr = parse_as_expression(end)
return wrap_expr(plr.business_day_count(start_pyexpr, end_pyexpr))
week_mask = _make_week_mask(weekend)
return wrap_expr(plr.business_day_count(start_pyexpr, end_pyexpr, week_mask))
9 changes: 9 additions & 0 deletions py-polars/polars/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@
"horizontal",
"align",
]
DayOfWeek = Literal[
"Mon",
"Tue",
"Wed",
"Thu",
"Fri",
"Sat",
"Sun",
]
EpochTimeUnit = Literal["ns", "us", "ms", "s", "d"]
Orientation: TypeAlias = Literal["col", "row"]
SearchSortedSide: TypeAlias = Literal["any", "left", "right"]
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/functions/business.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use pyo3::prelude::*;
use crate::PyExpr;

#[pyfunction]
pub fn business_day_count(start: PyExpr, end: PyExpr) -> PyExpr {
pub fn business_day_count(start: PyExpr, end: PyExpr, week_mask: [bool; 7]) -> PyExpr {
let start = start.inner;
let end = end.inner;
dsl::business_day_count(start, end).into()
dsl::business_day_count(start, end, week_mask).into()
}
16 changes: 14 additions & 2 deletions py-polars/tests/parametric/time_series/test_business_day_count.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,41 @@
from __future__ import annotations

import datetime as dt
from typing import TYPE_CHECKING

import hypothesis.strategies as st
import numpy as np
from hypothesis import given, reject

import polars as pl
from polars._utils.various import parse_version
from polars.functions.business import _make_week_mask

if TYPE_CHECKING:
from polars.type_aliases import DayOfWeek


@given(
start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)),
end=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)),
weekend=st.lists(
st.sampled_from(["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]),
min_size=0,
max_size=6,
unique=True,
),
)
def test_against_np_busday_count(
start: dt.date,
end: dt.date,
weekend: list[DayOfWeek],
) -> None:
result = (
pl.DataFrame({"start": [start], "end": [end]})
.select(n=pl.business_day_count("start", "end"))["n"]
.select(n=pl.business_day_count("start", "end", weekend=weekend))["n"]
.item()
)
expected = np.busday_count(start, end)
expected = np.busday_count(start, end, weekmask=_make_week_mask(weekend))
if start > end and parse_version(np.__version__) < parse_version("1.25"):
# Bug in old versions of numpy
reject()
Expand Down
36 changes: 36 additions & 0 deletions py-polars/tests/unit/functions/business/test_business_day_count.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import date

import pytest

import polars as pl
from polars.testing import assert_series_equal

Expand Down Expand Up @@ -50,6 +52,40 @@ def test_business_day_count() -> None:
assert_series_equal(result, expected)


def test_business_day_count_w_weekend() -> None:
    """Check business-day counts for several custom `weekend` settings."""
    column = "business_day_count"
    frame = pl.DataFrame(
        {
            "start": [date(2020, 1, 1), date(2020, 1, 2)],
            "end": [date(2020, 1, 2), date(2020, 1, 10)],
        }
    )

    def as_expected(counts: list[int]) -> pl.Series:
        # Counts come back as Int32 under the selected column name.
        return pl.Series(column, counts, pl.Int32)

    # Only Sunday is a day off.
    sunday_off = frame.select(
        business_day_count=pl.business_day_count("start", "end", weekend="Sun"),
    )[column]
    assert_series_equal(sunday_off, as_expected([1, 7]))

    # A three-day weekend passed as a tuple.
    long_weekend = frame.select(
        business_day_count=pl.business_day_count(
            "start", "end", weekend=("Thu", "Fri", "Sat")
        ),
    )[column]
    assert_series_equal(long_weekend, as_expected([1, 4]))

    # No weekend at all: every calendar day counts.
    no_weekend = frame.select(
        business_day_count=pl.business_day_count("start", "end", weekend=None),
    )[column]
    assert_series_equal(no_weekend, as_expected([1, 8]))


def test_business_day_count_w_weekend_invalid() -> None:
    """Check that an unrecognised day name in `weekend` raises `ValueError`."""
    # Parentheses are escaped because `match` treats the message as a regex.
    msg = r"Expected one of \('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'\), got: cabbage"
    # Invalid name passed as a bare string.
    with pytest.raises(ValueError, match=msg):
        pl.business_day_count("start", "end", weekend="cabbage")  # type: ignore[arg-type]
    # Invalid name mixed with a valid one inside a tuple.
    with pytest.raises(ValueError, match=msg):
        pl.business_day_count("start", "end", weekend=("Sat", "cabbage"))  # type: ignore[arg-type]


def test_business_day_count_schema() -> None:
lf = pl.LazyFrame(
{
Expand Down

0 comments on commit 7fff1e8

Please sign in to comment.