Skip to content

Commit

Permalink
fix(python, rust): Block rounding/truncating to negative durations (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-sil authored Mar 21, 2024
1 parent 645ce62 commit e6f7cb5
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
8 changes: 8 additions & 0 deletions crates/polars-time/src/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ pub trait PolarsRound {

impl PolarsRound for DatetimeChunked {
fn round(&self, every: Duration, offset: Duration, tz: Option<&Tz>) -> PolarsResult<Self> {
if every.negative {
polars_bail!(ComputeError: "cannot round a Datetime to a negative duration")
}

let w = Window::new(every, every, offset);

let func = match self.time_unit() {
Expand All @@ -27,6 +31,10 @@ impl PolarsRound for DatetimeChunked {

impl PolarsRound for DateChunked {
fn round(&self, every: Duration, offset: Duration, _tz: Option<&Tz>) -> PolarsResult<Self> {
if every.negative {
polars_bail!(ComputeError: "cannot round a Date to a negative duration")
}

let w = Window::new(every, every, offset);
Ok(self
.try_apply_values(|t| {
Expand Down
16 changes: 16 additions & 0 deletions crates/polars-time/src/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ impl PolarsTruncate for DatetimeChunked {
1 => {
if let Some(every) = every.get(0) {
let every = Duration::parse(every);
if every.negative {
polars_bail!(ComputeError: "cannot truncate a Datetime to a negative duration")
}

let w = Window::new(every, every, offset);
self.0.try_apply_values(|timestamp| func(&w, timestamp, tz))
} else {
Expand All @@ -35,6 +39,10 @@ impl PolarsTruncate for DatetimeChunked {
match (opt_timestamp, opt_every) {
(Some(timestamp), Some(every)) => {
let every = Duration::parse(every);
if every.negative {
polars_bail!(ComputeError: "cannot truncate a Datetime to a negative duration")
}

let w = Window::new(every, every, offset);
func(&w, timestamp, tz).map(Some)
},
Expand All @@ -58,6 +66,10 @@ impl PolarsTruncate for DateChunked {
1 => {
if let Some(every) = every.get(0) {
let every = Duration::parse(every);
if every.negative {
polars_bail!(ComputeError: "cannot truncate a Date to a negative duration")
}

let w = Window::new(every, every, offset);
self.try_apply_values(|t| {
const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY;
Expand All @@ -72,6 +84,10 @@ impl PolarsTruncate for DateChunked {
(Some(t), Some(every)) => {
const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY;
let every = Duration::parse(every);
if every.negative {
polars_bail!(ComputeError: "cannot truncate a Date to a negative duration")
}

let w = Window::new(every, every, offset);
Ok(Some(
(w.truncate_ms(MSECS_IN_DAY * t as i64, None)? / MSECS_IN_DAY) as i32,
Expand Down
44 changes: 44 additions & 0 deletions py-polars/tests/unit/namespaces/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,37 @@ def test_truncate(
assert out.dt[-1] == stop


def test_truncate_negative() -> None:
"""Test that truncating to a negative duration gives a helpful error message."""
df = pl.DataFrame(
{
"date": [date(1895, 5, 7), date(1955, 11, 5)],
"datetime": [datetime(1895, 5, 7), datetime(1955, 11, 5)],
"duration": ["-1m", "1m"],
}
)

with pytest.raises(
ComputeError, match="cannot truncate a Date to a negative duration"
):
df.select(pl.col("date").dt.truncate("-1m"))

with pytest.raises(
ComputeError, match="cannot truncate a Datetime to a negative duration"
):
df.select(pl.col("datetime").dt.truncate("-1m"))

with pytest.raises(
ComputeError, match="cannot truncate a Date to a negative duration"
):
df.select(pl.col("date").dt.truncate(pl.col("duration")))

with pytest.raises(
ComputeError, match="cannot truncate a Datetime to a negative duration"
):
df.select(pl.col("datetime").dt.truncate(pl.col("duration")))


@pytest.mark.parametrize(
("time_unit", "every"),
[
Expand Down Expand Up @@ -542,6 +573,19 @@ def test_round(
assert out.dt[-1] == stop


def test_round_negative() -> None:
"""Test that rounding to a negative duration gives a helpful error message."""
with pytest.raises(
ComputeError, match="cannot round a Date to a negative duration"
):
pl.Series([date(1895, 5, 7)]).dt.round("-1m")

with pytest.raises(
ComputeError, match="cannot round a Datetime to a negative duration"
):
pl.Series([datetime(1895, 5, 7)]).dt.round("-1m")


@pytest.mark.parametrize(
("time_unit", "date_in_that_unit"),
[
Expand Down

0 comments on commit e6f7cb5

Please sign in to comment.