From 54f9105d96ec6d0777fecf73e09bda4858bf37ff Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 4 Jun 2024 13:01:29 +0400 Subject: [PATCH] feat: Initial SQL support for `INTERVAL` strings --- Cargo.lock | 1 + crates/polars-plan/src/logical_plan/lit.rs | 11 ++ crates/polars-sql/Cargo.toml | 1 + crates/polars-sql/src/sql_expr.rs | 40 +++++- crates/polars-sql/tests/simple_exprs.rs | 5 +- crates/polars-time/src/windows/duration.rs | 137 ++++++++++++++++----- py-polars/tests/unit/sql/test_literals.py | 36 ++++++ py-polars/tests/unit/sql/test_numeric.py | 7 +- 8 files changed, 202 insertions(+), 36 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1b7eeacd2d76..26c370a0078f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3203,6 +3203,7 @@ dependencies = [ "polars-lazy", "polars-ops", "polars-plan", + "polars-time", "rand", "serde", "serde_json", diff --git a/crates/polars-plan/src/logical_plan/lit.rs b/crates/polars-plan/src/logical_plan/lit.rs index 4da749ecffdd..1c2dedddafa4 100644 --- a/crates/polars-plan/src/logical_plan/lit.rs +++ b/crates/polars-plan/src/logical_plan/lit.rs @@ -410,6 +410,17 @@ impl Literal for ChronoDuration { } } +#[cfg(feature = "dtype-duration")] +impl Literal for Duration { + fn lit(self) -> Expr { + let ns = self.duration_ns(); + Expr::Literal(LiteralValue::Duration( + if self.negative() { -ns } else { ns }, + TimeUnit::Nanoseconds, + )) + } +} + #[cfg(feature = "dtype-datetime")] impl Literal for NaiveDate { fn lit(self) -> Expr { diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 1db0e88a116c..7bd388f043d0 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -15,6 +15,7 @@ polars-error = { workspace = true } polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] } polars-ops = { workspace = true } polars-plan = { workspace = true } +polars-time = { workspace = true } hex = { workspace = true } once_cell = { workspace = true } diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 9e1af4610203..7d777e48b242 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -7,6 +7,7 @@ use polars_lazy::prelude::*; use polars_ops::series::SeriesReshape; use polars_plan::prelude::typed_lit; use polars_plan::prelude::LiteralValue::Null; +use polars_time::Duration; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; use regex::{Regex, RegexBuilder}; @@ -17,7 +18,7 @@ use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::{ ArrayAgg, ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, - JoinConstraint, ObjectName, OrderByExpr, Query as Subquery, SelectItem, TimezoneInfo, + Interval, JoinConstraint, ObjectName, OrderByExpr, Query as Subquery, SelectItem, TimezoneInfo, TrimWhereField, UnaryOperator, Value as SQLValue, }; use sqlparser::dialect::GenericDialect; @@ -259,6 +260,7 @@ impl SQLExprVisitor<'_> { subquery, negated, } => self.visit_in_subquery(expr, subquery, *negated), + SQLExpr::Interval(interval) => self.visit_interval(interval), SQLExpr::IsDistinctFrom(e1, e2) => { Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?)) }, @@ -409,6 +411,42 @@ impl SQLExprVisitor<'_> { } } + fn visit_interval(&self, interval: &Interval) -> PolarsResult { + if interval.last_field.is_some() + || interval.leading_field.is_some() + || interval.leading_precision.is_some() + || interval.fractional_seconds_precision.is_some() + { + polars_bail!(SQLInterface: "interval with explicit leading field or precision is not supported: {:?}", interval) + } + let mut negative = false; + let s = match &*interval.value { + SQLExpr::UnaryOp { + op: UnaryOperator::Minus, + expr, + } if matches!(**expr, SQLExpr::Value(SQLValue::SingleQuotedString(_))) => { + if let SQLExpr::Value(SQLValue::SingleQuotedString(ref s)) = **expr { + negative = true; + Some(s) + } else { + unreachable!() + } + }, + SQLExpr::Value(SQLValue::SingleQuotedString(s)) => Some(s), + _ => None, + }; + match s { + Some(s) if s.contains('-') => { + polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s) + }, + Some(s) => { + let d = Duration::parse_interval(s); + Ok(lit(if negative { -d } else { d })) + }, + None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval), + } + } + fn visit_like( &mut self, negated: bool, diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index 7f9aae7aeeaa..e50ef1234971 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -1,6 +1,7 @@ use polars_core::prelude::*; use polars_lazy::prelude::*; use polars_sql::*; +use polars_time::Duration; fn create_sample_df() -> PolarsResult { let a = Series::new("a", (1..10000i64).map(|i| i / 100).collect::>()); @@ -172,7 +173,8 @@ fn test_literal_exprs() { 1.0 as float_lit, 'foo' as string_lit, true as bool_lit, - null as null_lit + null as null_lit, + interval '1 quarter 2 weeks 1 day 50 seconds' as duration_lit FROM df"#; let df_sql = context.execute(sql).unwrap().collect().unwrap(); let df_pl = df @@ -183,6 +185,7 @@ fn test_literal_exprs() { lit("foo").alias("string_lit"), lit(true).alias("bool_lit"), lit(NULL).alias("null_lit"), + lit(Duration::parse("1q2w1d50s")).alias("duration_lit"), ]) .collect() .unwrap(); diff --git a/crates/polars-time/src/windows/duration.rs b/crates/polars-time/src/windows/duration.rs index a38d0e9cc217..b96d4b16e0d3 100644 --- a/crates/polars-time/src/windows/duration.rs +++ b/crates/polars-time/src/windows/duration.rs @@ -1,6 +1,6 @@ use std::cmp::Ordering; use std::fmt::{Display, Formatter}; -use std::ops::Mul; +use std::ops::{Mul, Neg}; #[cfg(feature = "timezones")] use arrow::legacy::kernels::{Ambiguous, NonExistent}; @@ -56,6 +56,21 @@ impl Ord for Duration { } } +impl Neg for Duration { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { + months: self.months, + weeks: self.weeks, + days: self.days, + nsecs: self.nsecs, + negative: !self.negative, + parsed_int: self.parsed_int, + } + } +} + impl Display for Duration { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { if self.is_zero() { @@ -138,20 +153,40 @@ impl Duration { /// # Panics /// If the given str is invalid for any reason. pub fn parse(duration: &str) -> Self { - let num_minus_signs = duration.matches('-').count(); + Self::_parse(duration, false) + } + + #[doc(hidden)] + /// Parse SQL-style "interval" string to Duration. Handles verbose + /// units (such as 'year', 'minutes', etc.) and whitespace, as + /// well as being case-insensitive. + pub fn parse_interval(interval: &str) -> Self { + Self::_parse(&interval.to_ascii_lowercase(), true) + } + + fn _parse(s: &str, as_interval: bool) -> Self { + let s = if as_interval { s.trim_start() } else { s }; + + let parse_type = if as_interval { "interval" } else { "duration" }; + let num_minus_signs = s.matches('-').count(); if num_minus_signs > 1 { - panic!("a Duration string can only have a single minus sign") + panic!("{} string can only have a single minus sign", parse_type) } - if (num_minus_signs > 0) & !duration.starts_with('-') { - panic!("only a single minus sign is allowed, at the front of the string") + if num_minus_signs > 0 { + if as_interval { + // TODO: intervals need to support per-element minus signs + panic!("minus signs are not currently supported in interval strings") + } else if !s.starts_with('-') { + panic!("only a single minus sign is allowed, at the front of the string") + } } - - let mut nsecs = 0; + let mut months = 0; let mut weeks = 0; let mut days = 0; - let mut months = 0; - let negative = duration.starts_with('-'); - let mut iter = duration.char_indices(); + let mut nsecs = 0; + + let negative = s.starts_with('-'); + let mut iter = s.char_indices().peekable(); let mut start = 0; // skip the '-' char @@ -159,37 +194,52 @@ impl Duration { start += 1; iter.next().unwrap(); } - + // permissive whitespace for intervals + if as_interval { + while let Some((i, ch)) = iter.peek() { + if *ch == ' ' { + start = *i + 1; + iter.next(); + } else { + break; + } + } + } + // reserve capacity for the longest valid unit ("microseconds") + let mut unit = String::with_capacity(12); let mut parsed_int = false; - let mut unit = String::with_capacity(2); while let Some((i, mut ch)) = iter.next() { if !ch.is_ascii_digit() { - let n = duration[start..i] - .parse::() - .expect("expected an integer in the duration string"); + let n = s[start..i].parse::().unwrap_or_else(|_| { + panic!( + "expected leading integer in the {} string, found {}", + parse_type, ch + ) + }); loop { - if ch.is_ascii_alphabetic() { - unit.push(ch) - } else { - break; + match ch { + c if c.is_ascii_alphabetic() => unit.push(c), + ' ' | ',' if as_interval => {}, + _ => break, } match iter.next() { Some((i, ch_)) => { ch = ch_; start = i }, - None => { - break; - }, + None => break, } } if unit.is_empty() { - panic!("expected a unit in the duration string") + panic!( + "expected a unit to follow integer in the {} string '{}'", + parse_type, s + ) } - match &*unit { + // matches that are allowed for both duration/interval "ns" => nsecs += n, "us" => nsecs += n * NS_MICROSECOND, "ms" => nsecs += n * NS_MILLISECOND, @@ -198,17 +248,38 @@ impl Duration { "h" => nsecs += n * NS_HOUR, "d" => days += n, "w" => weeks += n, - "mo" => { - months += n - } + "mo" => months += n, "q" => months += n * 3, "y" => months += n * 12, - // we will read indexes as nanoseconds "i" => { nsecs += n; parsed_int = true; - } - unit => panic!("unit: '{unit}' not supported. Available units are: 'ns', 'us', 'ms', 's', 'm', 'h', 'd', 'w', 'q', 'mo', 'y', 'i'"), + }, + _ if as_interval => match &*unit { + // interval-only (verbose/sql) matches + "nanosec" | "nanosecs" | "nanosecond" | "nanoseconds" => nsecs += n, + "microsec" | "microsecs" | "microsecond" | "microseconds" => { + nsecs += n * NS_MICROSECOND + }, + "millisec" | "millisecs" | "millisecond" | "milliseconds" => { + nsecs += n * NS_MILLISECOND + }, + "sec" | "secs" | "second" | "seconds" => nsecs += n * NS_SECOND, + "min" | "mins" | "minute" | "minutes" => nsecs += n * NS_MINUTE, + "hour" | "hours" => nsecs += n * NS_HOUR, + "day" | "days" => days += n, + "week" | "weeks" => weeks += n, + "mon" | "mons" | "month" | "months" => months += n, + "quarter" | "quarters" => months += n * 3, + "year" | "years" => months += n * 12, + _ => { + let valid_units = "'year', 'month', 'quarter', 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond', 'nanosecond'"; + panic!("unit: '{unit}' not supported; available units include: {} (and their plurals)", valid_units) + }, + }, + _ => { + panic!("unit: '{unit}' not supported; available units are: 'y', 'mo', 'q', 'w', 'd', 'h', 'm', 's', 'ms', 'us', 'ns'") + }, } unit.clear(); } @@ -954,7 +1025,7 @@ pub fn ensure_is_constant_duration( time_zone: Option<&str>, variable_name: &str, ) -> PolarsResult<()> { - polars_ensure!(duration.is_constant_duration(time_zone), + polars_ensure!(duration.is_constant_duration(time_zone), InvalidOperation: "expected `{}` to be a constant duration \ (i.e. one independent of differing month durations or of daylight savings time), got {}.\n\ \n\ @@ -971,11 +1042,11 @@ pub fn ensure_duration_matches_data_type( ) -> PolarsResult<()> { match data_type { DataType::Int64 | DataType::UInt64 | DataType::Int32 | DataType::UInt32 => { - polars_ensure!(duration.parsed_int || duration.is_zero(), + polars_ensure!(duration.parsed_int || duration.is_zero(), InvalidOperation: "`{}` duration must be a parsed integer (i.e. use '2i', not '2d') when working with a numeric column", variable_name); }, DataType::Datetime(_, _) | DataType::Date | DataType::Duration(_) | DataType::Time => { - polars_ensure!(!duration.parsed_int, + polars_ensure!(!duration.parsed_int, InvalidOperation: "`{}` duration may not be a parsed integer (i.e. use '2d', not '2i') when working with a temporal column", variable_name); }, _ => { diff --git a/py-polars/tests/unit/sql/test_literals.py b/py-polars/tests/unit/sql/test_literals.py index 8a5b86f68331..6a7c0e0733b1 100644 --- a/py-polars/tests/unit/sql/test_literals.py +++ b/py-polars/tests/unit/sql/test_literals.py @@ -1,9 +1,12 @@ from __future__ import annotations +from datetime import timedelta + import pytest import polars as pl from polars.exceptions import SQLInterfaceError, SQLSyntaxError +from polars.testing import assert_frame_equal def test_bit_hex_literals() -> None: @@ -88,3 +91,36 @@ def test_bit_hex_membership() -> None: ): dff = df.filter(pl.sql_expr(f"x IN ({values})")) assert dff["y"].to_list() == [1, 4] + + +def test_intervals() -> None: + with pl.SQLContext(df=None, eager=True) as ctx: + out = ctx.execute( + """ + SELECT + -- short form with/without spaces + INTERVAL '1w2h3m4s' AS i1, + INTERVAL '100ms 100us' AS i2, + -- long form with/without commas (case-insensitive) + INTERVAL '1 week, 2 hours, 3 minutes, 4 seconds' AS i3, + INTERVAL '1 Quarter 2 Months 987 Microseconds' AS i4, + FROM df + """ + ) + expected = pl.DataFrame( + { + "i1": [timedelta(weeks=1, hours=2, minutes=3, seconds=4)], + "i2": [timedelta(microseconds=100100)], + "i3": [timedelta(weeks=1, hours=2, minutes=3, seconds=4)], + "i4": [timedelta(days=140, microseconds=987)], + }, + ).cast(pl.Duration("ns")) + + assert_frame_equal(expected, out) + + # TODO: negative intervals + with pytest.raises( + SQLInterfaceError, + match="minus signs are not yet supported in interval strings; found '-7d'", + ): + ctx.execute("SELECT INTERVAL '-7d' AS one_week_ago FROM df") diff --git a/py-polars/tests/unit/sql/test_numeric.py b/py-polars/tests/unit/sql/test_numeric.py index 77ca5feaa275..c9aa4f4b00b8 100644 --- a/py-polars/tests/unit/sql/test_numeric.py +++ b/py-polars/tests/unit/sql/test_numeric.py @@ -14,7 +14,12 @@ def test_div() -> None: - df = pl.DataFrame({"a": [20.5, None, 10.0, 5.0, 2.5], "b": [6, 12, 24, None, 5]}) + df = pl.DataFrame( + { + "a": [20.5, None, 10.0, 5.0, 2.5], + "b": [6, 12, 24, None, 5], + } + ) res = df.sql("SELECT DIV(a, b) AS a_div_b, DIV(b, a) AS b_div_a FROM self") assert res.to_dict(as_series=False) == { "a_div_b": [3, None, 0, None, 0],