From 4d0e22e1f334215a77bb7c8d9f6d1a549fd0a244 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Mon, 10 Jun 2024 10:00:23 +0400 Subject: [PATCH] feat: Support SQL `VALUES` clause and naming columns in CTE & derived table definitions --- crates/polars-sql/src/context.rs | 60 +++++++++++++++++-- .../tests/unit/sql/test_miscellaneous.py | 34 ++++++++++- py-polars/tests/unit/sql/test_numeric.py | 20 ++++--- 3 files changed, 100 insertions(+), 14 deletions(-) diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 4921805d6f67c..1570b6e0c9193 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -1,5 +1,6 @@ use std::cell::RefCell; +use polars_core::frame::row::Row; use polars_core::prelude::*; use polars_lazy::prelude::*; use polars_ops::frame::JoinCoalesce; @@ -8,7 +9,7 @@ use sqlparser::ast::{ Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinConstraint, JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator, - Value as SQLValue, WildcardAdditionalOptions, + Value as SQLValue, Values, WildcardAdditionalOptions, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; @@ -275,6 +276,10 @@ impl SQLContext { SetExpr::SetOperation { op, .. } => { polars_bail!(SQLInterface: "'{}' operation not yet supported", op) }, + SetExpr::Values(Values { + explicit_row: _, + rows, + }) => self.process_values(rows), op => polars_bail!(SQLInterface: "'{}' operation not yet supported", op), } } @@ -315,6 +320,25 @@ impl SQLContext { } } + fn process_values(&mut self, values: &[Vec]) -> PolarsResult { + let frame_rows: Vec = values.iter().map(|row| { + let row_data: Result, _> = row.iter().map(|expr| { + let expr = parse_sql_expr(expr, self, None)?; + match expr { + Expr::Literal(value) => { + value.to_any_value() + .ok_or_else(|| polars_err!(SQLInterface: "invalid literal value: {:?}", value)) + .map(|av| av.into_static().unwrap()) + }, + _ => polars_bail!(SQLInterface: "VALUES clause expects literals; found {}", expr), + } + }).collect(); + row_data.map(Row::new) + }).collect::>()?; + + Ok(DataFrame::from_rows(frame_rows.as_ref())?.lazy()) + } + // EXPLAIN SELECT * FROM DF fn execute_explain(&mut self, stmt: &Statement) -> PolarsResult { match stmt { @@ -393,8 +417,9 @@ impl SQLContext { } for cte in &with.cte_tables { let cte_name = cte.alias.name.value.clone(); - let cte_lf = self.execute_query(&cte.query)?; - self.register_cte(&cte_name, cte_lf); + let mut lf = self.execute_query(&cte.query)?; + lf = self.rename_columns_from_table_alias(lf, &cte.alias)?; + self.register_cte(&cte_name, lf); } } Ok(()) @@ -764,7 +789,7 @@ impl SQLContext { name, alias, args, .. } => { if let Some(args) = args { - return self.execute_tbl_function(name, alias, args); + return self.execute_table_function(name, alias, args); } let tbl_name = name.0.first().unwrap().value.as_str(); if let Some(lf) = self.get_table_from_current_scope(tbl_name) { @@ -788,7 +813,8 @@ impl SQLContext { } => { polars_ensure!(!(*lateral), SQLInterface: "LATERAL not supported"); if let Some(alias) = alias { - let lf = self.execute_query_no_ctes(subquery)?; + let mut lf = self.execute_query_no_ctes(subquery)?; + lf = self.rename_columns_from_table_alias(lf, alias)?; self.table_map.insert(alias.name.value.clone(), lf.clone()); Ok((alias.name.value.clone(), lf)) } else { @@ -861,7 +887,7 @@ impl SQLContext { } } - fn execute_tbl_function( + fn execute_table_function( &mut self, name: &ObjectName, alias: &Option, @@ -1084,6 +1110,28 @@ impl SQLContext { _ => expr, }) } + + fn rename_columns_from_table_alias( + &mut self, + mut frame: LazyFrame, + alias: &TableAlias, + ) -> PolarsResult { + if alias.columns.is_empty() { + Ok(frame) + } else { + let schema = frame.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + if alias.columns.len() != schema.len() { + polars_bail!( + SQLSyntax: "number of columns ({}) in alias '{}' does not match the number of columns in the table/query ({})", + alias.columns.len(), alias.name.value, schema.len() + ) + } else { + let existing_columns: Vec<_> = schema.iter_names().collect(); + let new_columns: Vec<_> = alias.columns.iter().map(|c| c.value.clone()).collect(); + Ok(frame.rename(existing_columns, new_columns)) + } + } + } } impl SQLContext { diff --git a/py-polars/tests/unit/sql/test_miscellaneous.py b/py-polars/tests/unit/sql/test_miscellaneous.py index 8fa62a65c4a90..0ee8c332afa47 100644 --- a/py-polars/tests/unit/sql/test_miscellaneous.py +++ b/py-polars/tests/unit/sql/test_miscellaneous.py @@ -5,7 +5,7 @@ import pytest import polars as pl -from polars.exceptions import SQLInterfaceError +from polars.exceptions import SQLInterfaceError, SQLSyntaxError from polars.testing import assert_frame_equal @@ -194,3 +194,35 @@ def test_sql_on_compatible_frame_types() -> None: # don't register all compatible objects with pytest.raises(SQLInterfaceError, match="relation 'dfp' was not found"): pl.SQLContext(register_globals=True).execute("SELECT * FROM dfp") + + +def test_nested_cte_column_aliasing() -> None: + # trace through nested CTEs with multiple levels of column & table aliasing + df = pl.sql( + """ + WITH + x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)), + y AS ( + WITH z(c, d) AS (SELECT * FROM x) + SELECT d*2 AS d2, c*3 AS c3 FROM z + ) + SELECT c3, d2 FROM y + """, + eager=True, + ) + assert df.to_dict(as_series=False) == { + "c3": [3, 9], + "d2": [4, 8], + } + + +def test_invalid_derived_table_column_aliases() -> None: + values_query = "SELECT * FROM (VALUES (1,2), (3,4))" + + with pytest.raises( + SQLSyntaxError, + match=r"columns \(5\) in alias 'tbl' does not match .* the table/query \(2\)", + ): + pl.sql(f"{values_query} AS tbl(a, b, c, d, e)") + + assert pl.sql(f"{values_query} tbl", eager=True).rows() == [(1, 2), (3, 4)] diff --git a/py-polars/tests/unit/sql/test_numeric.py b/py-polars/tests/unit/sql/test_numeric.py index f075949a38640..77d149ccd06cf 100644 --- a/py-polars/tests/unit/sql/test_numeric.py +++ b/py-polars/tests/unit/sql/test_numeric.py @@ -14,14 +14,20 @@ def test_div() -> None: - df = pl.DataFrame( - { - "a": [20.5, None, 10.0, 5.0, 2.5], - "b": [6, 12, 24, None, 5], - }, - ) - res = df.sql("SELECT DIV(a, b) AS a_div_b, DIV(b, a) AS b_div_a FROM self") + res = pl.sql(""" + SELECT label, DIV(a, b) AS a_div_b, DIV(tbl.b, tbl.a) AS b_div_a + FROM ( + VALUES + ('a', 20.5, 6), + ('b', NULL, 12), + ('c', 10.0, 24), + ('d', 5.0, NULL), + ('e', 2.5, 5) + ) AS tbl(label, a, b) + """).collect() + assert res.to_dict(as_series=False) == { + "label": ["a", "b", "c", "d", "e"], "a_div_b": [3, None, 0, None, 0], "b_div_a": [0, None, 2, None, 2], }