Skip to content

Commit

Permalink
feat: Initial SQL support for INTERVAL strings (#16732)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Jun 5, 2024
1 parent a2a4157 commit a3f4ad9
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 36 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions crates/polars-plan/src/logical_plan/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,17 @@ impl Literal for ChronoDuration {
}
}

#[cfg(feature = "dtype-duration")]
impl Literal for Duration {
fn lit(self) -> Expr {
let ns = self.duration_ns();
Expr::Literal(LiteralValue::Duration(
if self.negative() { -ns } else { ns },
TimeUnit::Nanoseconds,
))
}
}

#[cfg(feature = "dtype-datetime")]
impl Literal for NaiveDate {
fn lit(self) -> Expr {
Expand Down
1 change: 1 addition & 0 deletions crates/polars-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ polars-error = { workspace = true }
polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] }
polars-ops = { workspace = true }
polars-plan = { workspace = true }
polars-time = { workspace = true }

hex = { workspace = true }
once_cell = { workspace = true }
Expand Down
40 changes: 39 additions & 1 deletion crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use polars_lazy::prelude::*;
use polars_ops::series::SeriesReshape;
use polars_plan::prelude::typed_lit;
use polars_plan::prelude::LiteralValue::Null;
use polars_time::Duration;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use regex::{Regex, RegexBuilder};
Expand All @@ -17,7 +18,7 @@ use sqlparser::ast::ExactNumberInfo;
use sqlparser::ast::{
ArrayAgg, ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat,
DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident,
JoinConstraint, ObjectName, OrderByExpr, Query as Subquery, SelectItem, TimezoneInfo,
Interval, JoinConstraint, ObjectName, OrderByExpr, Query as Subquery, SelectItem, TimezoneInfo,
TrimWhereField, UnaryOperator, Value as SQLValue,
};
use sqlparser::dialect::GenericDialect;
Expand Down Expand Up @@ -259,6 +260,7 @@ impl SQLExprVisitor<'_> {
subquery,
negated,
} => self.visit_in_subquery(expr, subquery, *negated),
SQLExpr::Interval(interval) => self.visit_interval(interval),
SQLExpr::IsDistinctFrom(e1, e2) => {
Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
},
Expand Down Expand Up @@ -409,6 +411,42 @@ impl SQLExprVisitor<'_> {
}
}

fn visit_interval(&self, interval: &Interval) -> PolarsResult<Expr> {
if interval.last_field.is_some()
|| interval.leading_field.is_some()
|| interval.leading_precision.is_some()
|| interval.fractional_seconds_precision.is_some()
{
polars_bail!(SQLInterface: "interval with explicit leading field or precision is not supported: {:?}", interval)
}
let mut negative = false;
let s = match &*interval.value {
SQLExpr::UnaryOp {
op: UnaryOperator::Minus,
expr,
} if matches!(**expr, SQLExpr::Value(SQLValue::SingleQuotedString(_))) => {
if let SQLExpr::Value(SQLValue::SingleQuotedString(ref s)) = **expr {
negative = true;
Some(s)
} else {
unreachable!()
}
},
SQLExpr::Value(SQLValue::SingleQuotedString(s)) => Some(s),
_ => None,
};
match s {
Some(s) if s.contains('-') => {
polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s)
},
Some(s) => {
let d = Duration::parse_interval(s);
Ok(lit(if negative { -d } else { d }))
},
None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval),
}
}

fn visit_like(
&mut self,
negated: bool,
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-sql/tests/simple_exprs.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use polars_core::prelude::*;
use polars_lazy::prelude::*;
use polars_sql::*;
use polars_time::Duration;

fn create_sample_df() -> PolarsResult<DataFrame> {
let a = Series::new("a", (1..10000i64).map(|i| i / 100).collect::<Vec<_>>());
Expand Down Expand Up @@ -172,7 +173,8 @@ fn test_literal_exprs() {
1.0 as float_lit,
'foo' as string_lit,
true as bool_lit,
null as null_lit
null as null_lit,
interval '1 quarter 2 weeks 1 day 50 seconds' as duration_lit
FROM df"#;
let df_sql = context.execute(sql).unwrap().collect().unwrap();
let df_pl = df
Expand All @@ -183,6 +185,7 @@ fn test_literal_exprs() {
lit("foo").alias("string_lit"),
lit(true).alias("bool_lit"),
lit(NULL).alias("null_lit"),
lit(Duration::parse("1q2w1d50s")).alias("duration_lit"),
])
.collect()
.unwrap();
Expand Down
137 changes: 104 additions & 33 deletions crates/polars-time/src/windows/duration.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::cmp::Ordering;
use std::fmt::{Display, Formatter};
use std::ops::Mul;
use std::ops::{Mul, Neg};

#[cfg(feature = "timezones")]
use arrow::legacy::kernels::{Ambiguous, NonExistent};
Expand Down Expand Up @@ -56,6 +56,21 @@ impl Ord for Duration {
}
}

impl Neg for Duration {
type Output = Self;

fn neg(self) -> Self::Output {
Self {
months: self.months,
weeks: self.weeks,
days: self.days,
nsecs: self.nsecs,
negative: !self.negative,
parsed_int: self.parsed_int,
}
}
}

impl Display for Duration {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
if self.is_zero() {
Expand Down Expand Up @@ -138,58 +153,93 @@ impl Duration {
/// # Panics
/// If the given str is invalid for any reason.
pub fn parse(duration: &str) -> Self {
let num_minus_signs = duration.matches('-').count();
Self::_parse(duration, false)
}

#[doc(hidden)]
/// Parse SQL-style "interval" string to Duration. Handles verbose
/// units (such as 'year', 'minutes', etc.) and whitespace, as
/// well as being case-insensitive.
pub fn parse_interval(interval: &str) -> Self {
Self::_parse(&interval.to_ascii_lowercase(), true)
}

fn _parse(s: &str, as_interval: bool) -> Self {
let s = if as_interval { s.trim_start() } else { s };

let parse_type = if as_interval { "interval" } else { "duration" };
let num_minus_signs = s.matches('-').count();
if num_minus_signs > 1 {
panic!("a Duration string can only have a single minus sign")
panic!("{} string can only have a single minus sign", parse_type)
}
if (num_minus_signs > 0) & !duration.starts_with('-') {
panic!("only a single minus sign is allowed, at the front of the string")
if num_minus_signs > 0 {
if as_interval {
// TODO: intervals need to support per-element minus signs
panic!("minus signs are not currently supported in interval strings")
} else if !s.starts_with('-') {
panic!("only a single minus sign is allowed, at the front of the string")
}
}

let mut nsecs = 0;
let mut months = 0;
let mut weeks = 0;
let mut days = 0;
let mut months = 0;
let negative = duration.starts_with('-');
let mut iter = duration.char_indices();
let mut nsecs = 0;

let negative = s.starts_with('-');
let mut iter = s.char_indices().peekable();
let mut start = 0;

// skip the '-' char
if negative {
start += 1;
iter.next().unwrap();
}

// permissive whitespace for intervals
if as_interval {
while let Some((i, ch)) = iter.peek() {
if *ch == ' ' {
start = *i + 1;
iter.next();
} else {
break;
}
}
}
// reserve capacity for the longest valid unit ("microseconds")
let mut unit = String::with_capacity(12);
let mut parsed_int = false;

let mut unit = String::with_capacity(2);
while let Some((i, mut ch)) = iter.next() {
if !ch.is_ascii_digit() {
let n = duration[start..i]
.parse::<i64>()
.expect("expected an integer in the duration string");
let n = s[start..i].parse::<i64>().unwrap_or_else(|_| {
panic!(
"expected leading integer in the {} string, found {}",
parse_type, ch
)
});

loop {
if ch.is_ascii_alphabetic() {
unit.push(ch)
} else {
break;
match ch {
c if c.is_ascii_alphabetic() => unit.push(c),
' ' | ',' if as_interval => {},
_ => break,
}
match iter.next() {
Some((i, ch_)) => {
ch = ch_;
start = i
},
None => {
break;
},
None => break,
}
}
if unit.is_empty() {
panic!("expected a unit in the duration string")
panic!(
"expected a unit to follow integer in the {} string '{}'",
parse_type, s
)
}

match &*unit {
// matches that are allowed for both duration/interval
"ns" => nsecs += n,
"us" => nsecs += n * NS_MICROSECOND,
"ms" => nsecs += n * NS_MILLISECOND,
Expand All @@ -198,17 +248,38 @@ impl Duration {
"h" => nsecs += n * NS_HOUR,
"d" => days += n,
"w" => weeks += n,
"mo" => {
months += n
}
"mo" => months += n,
"q" => months += n * 3,
"y" => months += n * 12,
// we will read indexes as nanoseconds
"i" => {
nsecs += n;
parsed_int = true;
}
unit => panic!("unit: '{unit}' not supported. Available units are: 'ns', 'us', 'ms', 's', 'm', 'h', 'd', 'w', 'q', 'mo', 'y', 'i'"),
},
_ if as_interval => match &*unit {
// interval-only (verbose/sql) matches
"nanosec" | "nanosecs" | "nanosecond" | "nanoseconds" => nsecs += n,
"microsec" | "microsecs" | "microsecond" | "microseconds" => {
nsecs += n * NS_MICROSECOND
},
"millisec" | "millisecs" | "millisecond" | "milliseconds" => {
nsecs += n * NS_MILLISECOND
},
"sec" | "secs" | "second" | "seconds" => nsecs += n * NS_SECOND,
"min" | "mins" | "minute" | "minutes" => nsecs += n * NS_MINUTE,
"hour" | "hours" => nsecs += n * NS_HOUR,
"day" | "days" => days += n,
"week" | "weeks" => weeks += n,
"mon" | "mons" | "month" | "months" => months += n,
"quarter" | "quarters" => months += n * 3,
"year" | "years" => months += n * 12,
_ => {
let valid_units = "'year', 'month', 'quarter', 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond', 'nanosecond'";
panic!("unit: '{unit}' not supported; available units include: {} (and their plurals)", valid_units)
},
},
_ => {
panic!("unit: '{unit}' not supported; available units are: 'y', 'mo', 'q', 'w', 'd', 'h', 'm', 's', 'ms', 'us', 'ns'")
},
}
unit.clear();
}
Expand Down Expand Up @@ -954,7 +1025,7 @@ pub fn ensure_is_constant_duration(
time_zone: Option<&str>,
variable_name: &str,
) -> PolarsResult<()> {
polars_ensure!(duration.is_constant_duration(time_zone),
polars_ensure!(duration.is_constant_duration(time_zone),
InvalidOperation: "expected `{}` to be a constant duration \
(i.e. one independent of differing month durations or of daylight savings time), got {}.\n\
\n\
Expand All @@ -971,11 +1042,11 @@ pub fn ensure_duration_matches_data_type(
) -> PolarsResult<()> {
match data_type {
DataType::Int64 | DataType::UInt64 | DataType::Int32 | DataType::UInt32 => {
polars_ensure!(duration.parsed_int || duration.is_zero(),
polars_ensure!(duration.parsed_int || duration.is_zero(),
InvalidOperation: "`{}` duration must be a parsed integer (i.e. use '2i', not '2d') when working with a numeric column", variable_name);
},
DataType::Datetime(_, _) | DataType::Date | DataType::Duration(_) | DataType::Time => {
polars_ensure!(!duration.parsed_int,
polars_ensure!(!duration.parsed_int,
InvalidOperation: "`{}` duration may not be a parsed integer (i.e. use '2d', not '2i') when working with a temporal column", variable_name);
},
_ => {
Expand Down
36 changes: 36 additions & 0 deletions py-polars/tests/unit/sql/test_literals.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from datetime import timedelta

import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError, SQLSyntaxError
from polars.testing import assert_frame_equal


def test_bit_hex_literals() -> None:
Expand Down Expand Up @@ -88,3 +91,36 @@ def test_bit_hex_membership() -> None:
):
dff = df.filter(pl.sql_expr(f"x IN ({values})"))
assert dff["y"].to_list() == [1, 4]


def test_intervals() -> None:
with pl.SQLContext(df=None, eager=True) as ctx:
out = ctx.execute(
"""
SELECT
-- short form with/without spaces
INTERVAL '1w2h3m4s' AS i1,
INTERVAL '100ms 100us' AS i2,
-- long form with/without commas (case-insensitive)
INTERVAL '1 week, 2 hours, 3 minutes, 4 seconds' AS i3,
INTERVAL '1 Quarter 2 Months 987 Microseconds' AS i4,
FROM df
"""
)
expected = pl.DataFrame(
{
"i1": [timedelta(weeks=1, hours=2, minutes=3, seconds=4)],
"i2": [timedelta(microseconds=100100)],
"i3": [timedelta(weeks=1, hours=2, minutes=3, seconds=4)],
"i4": [timedelta(days=140, microseconds=987)],
},
).cast(pl.Duration("ns"))

assert_frame_equal(expected, out)

# TODO: negative intervals
with pytest.raises(
SQLInterfaceError,
match="minus signs are not yet supported in interval strings; found '-7d'",
):
ctx.execute("SELECT INTERVAL '-7d' AS one_week_ago FROM df")
Loading

0 comments on commit a3f4ad9

Please sign in to comment.