Skip to content

Commit

Permalink
sort, then exclude weekend + duplicates
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Apr 10, 2024
1 parent 4b5d9b4 commit cabcbbd
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
35 changes: 20 additions & 15 deletions crates/polars-ops/src/series/ops/business.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use ahash::HashSet;
use polars_core::prelude::arity::binary_elementwise_values;
use polars_core::prelude::*;

Expand All @@ -20,16 +19,7 @@ pub fn business_day_count(
polars_bail!(ComputeError:"`week_mask` must have at least one business day");
}

// De-dupe and sort holidays, and exclude non-business days.
let mut holidays: Vec<i32> = holidays
.iter()
.filter(|&x| *unsafe { week_mask.get_unchecked(weekday(*x)) })
.cloned()
.collect::<HashSet<_>>()
.into_iter()
.collect();
holidays.sort_unstable();

let holidays = normalise_holidays(holidays, &week_mask);
let start_dates = start.date()?;
let end_dates = end.date()?;
let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;
Expand Down Expand Up @@ -98,10 +88,10 @@ fn business_day_count_impl(
Ok(x) => x,
Err(x) => x,
} as i32;
let holidays_end = match holidays.binary_search(&end_date) {
Ok(x) => x,
Err(x) => x,
} as i32;
let holidays_end = match holidays[(holidays_begin as usize)..].binary_search(&end_date) {
Ok(x) => x as i32 + holidays_begin,
Err(x) => x as i32 + holidays_begin,
};

let mut start_weekday = weekday(start_date);
let diff = end_date - start_date;
Expand All @@ -126,6 +116,21 @@ fn business_day_count_impl(
}
}

/// Sort and deduplicate holidays and remove holidays that are not business days.
fn normalise_holidays(holidays: &[i32], week_mask: &[bool; 7]) -> Vec<i32> {
let mut holidays: Vec<i32> = holidays.to_vec();
holidays.sort_unstable();
let mut previous_holiday: Option<i32> = None;
holidays.retain(|&x| {
if (Some(x) == previous_holiday) || !unsafe { *week_mask.get_unchecked(weekday(x)) } {
return false;
}
previous_holiday = Some(x);
true
});
holidays
}

fn weekday(x: i32) -> usize {
// the first modulo might return a negative number, so we add 7 and take
// the modulo again so we're sure we have something between 0 (Monday)
Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/unit/functions/business/test_business_day_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,19 @@ def test_business_day_count_schema() -> None:
assert result.schema["business_day_count"] == pl.Int32
assert result.collect().schema["business_day_count"] == pl.Int32
assert 'col("start").business_day_count([col("end")])' in result.explain()


def test_business_day_count_w_holidays() -> None:
df = pl.DataFrame(
{
"start": [date(2020, 1, 1), date(2020, 1, 2)],
"end": [date(2020, 1, 2), date(2020, 1, 10)],
}
)
result = df.select(
business_day_count=pl.business_day_count(
"start", "end", holidays=[date(2020, 1, 1)]
),
)["business_day_count"]
expected = pl.Series("business_day_count", [0, 6], pl.Int32)
assert_series_equal(result, expected)

0 comments on commit cabcbbd

Please sign in to comment.