diff --git a/crates/polars-ops/src/series/ops/business.rs b/crates/polars-ops/src/series/ops/business.rs
index 19562072fa95..ea786c96e850 100644
--- a/crates/polars-ops/src/series/ops/business.rs
+++ b/crates/polars-ops/src/series/ops/business.rs
@@ -1,4 +1,3 @@
-use ahash::HashSet;
 use polars_core::prelude::arity::binary_elementwise_values;
 use polars_core::prelude::*;
 
@@ -20,16 +19,7 @@ pub fn business_day_count(
         polars_bail!(ComputeError:"`week_mask` must have at least one business day");
     }
 
-    // De-dupe and sort holidays, and exclude non-business days.
-    let mut holidays: Vec<i32> = holidays
-        .iter()
-        .filter(|&x| *unsafe { week_mask.get_unchecked(weekday(*x)) })
-        .cloned()
-        .collect::<HashSet<_>>()
-        .into_iter()
-        .collect();
-    holidays.sort_unstable();
-
+    let holidays = normalise_holidays(holidays, &week_mask);
     let start_dates = start.date()?;
     let end_dates = end.date()?;
     let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;
@@ -98,10 +88,10 @@ fn business_day_count_impl(
         Ok(x) => x,
         Err(x) => x,
     } as i32;
-    let holidays_end = match holidays.binary_search(&end_date) {
-        Ok(x) => x,
-        Err(x) => x,
-    } as i32;
+    let holidays_end = match holidays[(holidays_begin as usize)..].binary_search(&end_date) {
+        Ok(x) => x as i32 + holidays_begin,
+        Err(x) => x as i32 + holidays_begin,
+    };
 
     let mut start_weekday = weekday(start_date);
     let diff = end_date - start_date;
@@ -126,6 +116,21 @@ fn business_day_count_impl(
     }
 }
 
+/// Sort and deduplicate holidays and remove holidays that are not business days.
+fn normalise_holidays(holidays: &[i32], week_mask: &[bool; 7]) -> Vec<i32> {
+    let mut holidays: Vec<i32> = holidays.to_vec();
+    holidays.sort_unstable();
+    let mut previous_holiday: Option<i32> = None;
+    holidays.retain(|&x| {
+        if (Some(x) == previous_holiday) || !unsafe { *week_mask.get_unchecked(weekday(x)) } {
+            return false;
+        }
+        previous_holiday = Some(x);
+        true
+    });
+    holidays
+}
+
 fn weekday(x: i32) -> usize {
     // the first modulo might return a negative number, so we add 7 and take
     // the modulo again so we're sure we have something between 0 (Monday)
diff --git a/py-polars/tests/unit/functions/business/test_business_day_count.py b/py-polars/tests/unit/functions/business/test_business_day_count.py
index 13a1a05dbb7b..94ce1226c15b 100644
--- a/py-polars/tests/unit/functions/business/test_business_day_count.py
+++ b/py-polars/tests/unit/functions/business/test_business_day_count.py
@@ -104,3 +104,19 @@ def test_business_day_count_schema() -> None:
     assert result.schema["business_day_count"] == pl.Int32
     assert result.collect().schema["business_day_count"] == pl.Int32
     assert 'col("start").business_day_count([col("end")])' in result.explain()
+
+
+def test_business_day_count_w_holidays() -> None:
+    df = pl.DataFrame(
+        {
+            "start": [date(2020, 1, 1), date(2020, 1, 2)],
+            "end": [date(2020, 1, 2), date(2020, 1, 10)],
+        }
+    )
+    result = df.select(
+        business_day_count=pl.business_day_count(
+            "start", "end", holidays=[date(2020, 1, 1)]
+        ),
+    )["business_day_count"]
+    expected = pl.Series("business_day_count", [0, 6], pl.Int32)
+    assert_series_equal(result, expected)