From 13d68aede9e4194b0576f662c0cb7bfafed07289 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 11 Jun 2024 10:18:13 +0400 Subject: [PATCH] feat: Support SQL `VALUES` clause and inline renaming of columns in CTE & derived table definitions (#16851) --- crates/polars-sql/Cargo.toml | 8 +-- crates/polars-sql/src/context.rs | 60 +++++++++++++++++-- .../tests/unit/sql/test_miscellaneous.py | 49 ++++++++++++++- py-polars/tests/unit/sql/test_numeric.py | 20 ++++--- 4 files changed, 119 insertions(+), 18 deletions(-) diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 7bd388f043d0..83dcd5b98a3c 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -10,7 +10,7 @@ description = "SQL transpiler for Polars. Converts SQL to Polars logical plans" [dependencies] arrow = { workspace = true } -polars-core = { workspace = true } +polars-core = { workspace = true, features = ["rows"] } polars-error = { workspace = true } polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] } polars-ops = { workspace = true } @@ -32,12 +32,12 @@ polars-core = { workspace = true, features = ["fmt"] } [features] default = [] nightly = [] -csv = ["polars-lazy/csv"] -ipc = ["polars-lazy/ipc"] -json = ["polars-lazy/json", "polars-plan/extract_jsonpath"] binary_encoding = ["polars-lazy/binary_encoding"] +csv = ["polars-lazy/csv"] diagonal_concat = ["polars-lazy/diagonal_concat"] dtype-decimal = ["polars-lazy/dtype-decimal"] +ipc = ["polars-lazy/ipc"] +json = ["polars-lazy/json", "polars-plan/extract_jsonpath"] list_eval = ["polars-lazy/list_eval"] parquet = ["polars-lazy/parquet"] semi_anti_join = ["polars-lazy/semi_anti_join"] diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 4921805d6f67..1570b6e0c919 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -1,5 +1,6 @@ use std::cell::RefCell; +use polars_core::frame::row::Row; use polars_core::prelude::*; use polars_lazy::prelude::*; use polars_ops::frame::JoinCoalesce; @@ -8,7 +9,7 @@ use sqlparser::ast::{ Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinConstraint, JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator, - Value as SQLValue, WildcardAdditionalOptions, + Value as SQLValue, Values, WildcardAdditionalOptions, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; @@ -275,6 +276,10 @@ impl SQLContext { SetExpr::SetOperation { op, .. } => { polars_bail!(SQLInterface: "'{}' operation not yet supported", op) }, + SetExpr::Values(Values { + explicit_row: _, + rows, + }) => self.process_values(rows), op => polars_bail!(SQLInterface: "'{}' operation not yet supported", op), } } @@ -315,6 +320,25 @@ impl SQLContext { } } + fn process_values(&mut self, values: &[Vec]) -> PolarsResult { + let frame_rows: Vec = values.iter().map(|row| { + let row_data: Result, _> = row.iter().map(|expr| { + let expr = parse_sql_expr(expr, self, None)?; + match expr { + Expr::Literal(value) => { + value.to_any_value() + .ok_or_else(|| polars_err!(SQLInterface: "invalid literal value: {:?}", value)) + .map(|av| av.into_static().unwrap()) + }, + _ => polars_bail!(SQLInterface: "VALUES clause expects literals; found {}", expr), + } + }).collect(); + row_data.map(Row::new) + }).collect::>()?; + + Ok(DataFrame::from_rows(frame_rows.as_ref())?.lazy()) + } + // EXPLAIN SELECT * FROM DF fn execute_explain(&mut self, stmt: &Statement) -> PolarsResult { match stmt { @@ -393,8 +417,9 @@ impl SQLContext { } for cte in &with.cte_tables { let cte_name = cte.alias.name.value.clone(); - let cte_lf = self.execute_query(&cte.query)?; - self.register_cte(&cte_name, cte_lf); + let mut lf = self.execute_query(&cte.query)?; + lf = self.rename_columns_from_table_alias(lf, &cte.alias)?; + self.register_cte(&cte_name, lf); } } Ok(()) @@ -764,7 +789,7 @@ impl SQLContext { name, alias, args, .. } => { if let Some(args) = args { - return self.execute_tbl_function(name, alias, args); + return self.execute_table_function(name, alias, args); } let tbl_name = name.0.first().unwrap().value.as_str(); if let Some(lf) = self.get_table_from_current_scope(tbl_name) { @@ -788,7 +813,8 @@ impl SQLContext { } => { polars_ensure!(!(*lateral), SQLInterface: "LATERAL not supported"); if let Some(alias) = alias { - let lf = self.execute_query_no_ctes(subquery)?; + let mut lf = self.execute_query_no_ctes(subquery)?; + lf = self.rename_columns_from_table_alias(lf, alias)?; self.table_map.insert(alias.name.value.clone(), lf.clone()); Ok((alias.name.value.clone(), lf)) } else { @@ -861,7 +887,7 @@ impl SQLContext { } } - fn execute_tbl_function( + fn execute_table_function( &mut self, name: &ObjectName, alias: &Option, @@ -1084,6 +1110,28 @@ impl SQLContext { _ => expr, }) } + + fn rename_columns_from_table_alias( + &mut self, + mut frame: LazyFrame, + alias: &TableAlias, + ) -> PolarsResult { + if alias.columns.is_empty() { + Ok(frame) + } else { + let schema = frame.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + if alias.columns.len() != schema.len() { + polars_bail!( + SQLSyntax: "number of columns ({}) in alias '{}' does not match the number of columns in the table/query ({})", + alias.columns.len(), alias.name.value, schema.len() + ) + } else { + let existing_columns: Vec<_> = schema.iter_names().collect(); + let new_columns: Vec<_> = alias.columns.iter().map(|c| c.value.clone()).collect(); + Ok(frame.rename(existing_columns, new_columns)) + } + } + } } impl SQLContext { diff --git a/py-polars/tests/unit/sql/test_miscellaneous.py b/py-polars/tests/unit/sql/test_miscellaneous.py index 8fa62a65c4a9..acd51966292c 100644 --- a/py-polars/tests/unit/sql/test_miscellaneous.py +++ b/py-polars/tests/unit/sql/test_miscellaneous.py @@ -5,7 +5,7 @@ import pytest import polars as pl -from polars.exceptions import SQLInterfaceError +from polars.exceptions import SQLInterfaceError, SQLSyntaxError from polars.testing import assert_frame_equal @@ -194,3 +194,50 @@ def test_sql_on_compatible_frame_types() -> None: # don't register all compatible objects with pytest.raises(SQLInterfaceError, match="relation 'dfp' was not found"): pl.SQLContext(register_globals=True).execute("SELECT * FROM dfp") + + +def test_nested_cte_column_aliasing() -> None: + # trace through nested CTEs with multiple levels of column & table aliasing + df = pl.sql( + """ + WITH + x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)), + y (m, n) AS ( + WITH z(c, d) AS (SELECT a, b FROM x) + SELECT d*2 AS d2, c*3 AS c3 FROM z + ) + SELECT n, m FROM y + """, + eager=True, + ) + assert df.to_dict(as_series=False) == { + "n": [3, 9], + "m": [4, 8], + } + + +def test_invalid_derived_table_column_aliases() -> None: + values_query = "SELECT * FROM (VALUES (1,2), (3,4))" + + with pytest.raises( + SQLSyntaxError, + match=r"columns \(5\) in alias 'tbl' does not match .* the table/query \(2\)", + ): + pl.sql(f"{values_query} AS tbl(a, b, c, d, e)") + + assert pl.sql(f"{values_query} tbl", eager=True).rows() == [(1, 2), (3, 4)] + + +def test_values_clause_table_registration() -> None: + with pl.SQLContext(frames=None, eager=True) as ctx: + # initially no tables are registered + assert ctx.tables() == [] + + # confirm that VALUES clause derived table is registered, post-query + res1 = ctx.execute("SELECT * FROM (VALUES (-1,1)) AS tbl(x, y)") + assert ctx.tables() == ["tbl"] + + # and confirm that we can select from it by the registered name + res2 = ctx.execute("SELECT x, y FROM tbl") + for res in (res1, res2): + assert res.to_dict(as_series=False) == {"x": [-1], "y": [1]} diff --git a/py-polars/tests/unit/sql/test_numeric.py b/py-polars/tests/unit/sql/test_numeric.py index f075949a3864..77d149ccd06c 100644 --- a/py-polars/tests/unit/sql/test_numeric.py +++ b/py-polars/tests/unit/sql/test_numeric.py @@ -14,14 +14,20 @@ def test_div() -> None: - df = pl.DataFrame( - { - "a": [20.5, None, 10.0, 5.0, 2.5], - "b": [6, 12, 24, None, 5], - }, - ) - res = df.sql("SELECT DIV(a, b) AS a_div_b, DIV(b, a) AS b_div_a FROM self") + res = pl.sql(""" + SELECT label, DIV(a, b) AS a_div_b, DIV(tbl.b, tbl.a) AS b_div_a + FROM ( + VALUES + ('a', 20.5, 6), + ('b', NULL, 12), + ('c', 10.0, 24), + ('d', 5.0, NULL), + ('e', 2.5, 5) + ) AS tbl(label, a, b) + """).collect() + assert res.to_dict(as_series=False) == { + "label": ["a", "b", "c", "d", "e"], "a_div_b": [3, None, 0, None, 0], "b_div_a": [0, None, 2, None, 2], }