From 239ddde387118931d5650f9c8b08fc2152aedbb8 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sat, 22 Jun 2024 00:01:40 +0400 Subject: [PATCH] add `struct.*` wildcard selection support --- .../src/plans/conversion/expr_expansion.rs | 2 +- .../polars-plan/src/plans/conversion/mod.rs | 2 +- crates/polars-plan/src/plans/mod.rs | 2 +- crates/polars-plan/src/prelude.rs | 1 + crates/polars-sql/src/context.rs | 38 ++++++-------- crates/polars-sql/src/sql_expr.rs | 17 ++++--- crates/polars-sql/tests/simple_exprs.rs | 51 +++++++++++++++++++ py-polars/tests/unit/sql/test_structs.py | 44 ++++++++++++---- 8 files changed, 114 insertions(+), 43 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/expr_expansion.rs index 4f168d1080d15..fee7913dede7d 100644 --- a/crates/polars-plan/src/plans/conversion/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/expr_expansion.rs @@ -634,7 +634,7 @@ fn find_flags(expr: &Expr) -> PolarsResult { /// In case of single col(*) -> do nothing, no selection is the same as select all /// In other cases replace the wildcard with an expression with all columns -pub(crate) fn rewrite_projections( +pub fn rewrite_projections( exprs: Vec, schema: &Schema, keys: &[Expr], diff --git a/crates/polars-plan/src/plans/conversion/mod.rs b/crates/polars-plan/src/plans/conversion/mod.rs index d0d6a41e9fb96..afdac2d300fc8 100644 --- a/crates/polars-plan/src/plans/conversion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/mod.rs @@ -1,6 +1,6 @@ mod convert_utils; mod dsl_to_ir; -mod expr_expansion; +pub(crate) mod expr_expansion; mod expr_to_ir; mod ir_to_dsl; #[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index ca9acc44cf532..9255c811e4893 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -16,7 +16,7 @@ pub(crate) mod ir; mod apply; mod builder_dsl; mod builder_ir; -pub(crate) mod conversion; +pub mod conversion; #[cfg(feature = "debugging")] pub(crate) mod debug; pub mod expr_ir; diff --git a/crates/polars-plan/src/prelude.rs b/crates/polars-plan/src/prelude.rs index d90e032cc925b..34c38cefbdabc 100644 --- a/crates/polars-plan/src/prelude.rs +++ b/crates/polars-plan/src/prelude.rs @@ -11,6 +11,7 @@ pub(crate) use polars_time::prelude::*; pub use polars_utils::arena::{Arena, Node}; pub use crate::dsl::*; +pub use crate::plans::conversion::expr_expansion::rewrite_projections; #[cfg(feature = "debugging")] pub use crate::plans::debug::*; pub use crate::plans::options::*; diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 5e44f19563325..2ed4ea8849079 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -6,7 +6,7 @@ use polars_lazy::prelude::*; use polars_ops::frame::JoinCoalesce; use polars_plan::prelude::*; use sqlparser::ast::{ - Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinConstraint, + Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident, JoinConstraint, JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value as SQLValue, Values, WildcardAdditionalOptions, @@ -600,7 +600,7 @@ impl SQLContext { lf = self.process_where(lf, &select_stmt.selection)?; // Column projections. - let projections: Vec<_> = select_stmt + let mut projections: Vec = select_stmt .projection .iter() .map(|select_item| { @@ -610,11 +610,12 @@ impl SQLContext { let expr = parse_sql_expr(expr, self, schema.as_deref())?; expr.alias(&alias.value) }, - SelectItem::QualifiedWildcard(oname, wildcard_options) => self + SelectItem::QualifiedWildcard(obj_name, wildcard_options) => self .process_qualified_wildcard( - oname, + obj_name, wildcard_options, &mut contains_wildcard_exclude, + schema.as_deref(), )?, SelectItem::Wildcard(wildcard_options) => { contains_wildcard = true; @@ -629,7 +630,10 @@ impl SQLContext { }) .collect::>()?; - // Check for "GROUP BY ..." (after projections, as there may be ordinal/position ints). + // expand/flatten projections, so we have distinct expressions/columns + projections = rewrite_projections(projections, &(schema.clone().unwrap()), &[])?; + + // Check for "GROUP BY ..." (after determining projections) let mut group_by_keys: Vec = Vec::new(); match &select_stmt.group_by { // Standard "GROUP BY x, y, z" syntax (also recognising ordinal values) @@ -1152,25 +1156,13 @@ impl SQLContext { ObjectName(idents): &ObjectName, options: &WildcardAdditionalOptions, contains_wildcard_exclude: &mut bool, + schema: Option<&Schema>, ) -> PolarsResult { - let idents = idents.as_slice(); - let e = match idents { - [tbl_name] => { - let lf = self.table_map.get_mut(&tbl_name.value).ok_or_else(|| { - polars_err!( - SQLInterface: "no table named '{}' found", - tbl_name - ) - })?; - let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; - cols(schema.iter_names()) - }, - e => polars_bail!( - SQLSyntax: "invalid wildcard expression ({:?})", - e - ), - }; - self.process_wildcard_additional_options(e, options, contains_wildcard_exclude) + let mut new_idents = idents.clone(); + new_idents.push(Ident::new("*")); + let identifier = SQLExpr::CompoundIdentifier(new_idents); + let expr = parse_sql_expr(&identifier, self, schema)?; + self.process_wildcard_additional_options(expr, options, contains_wildcard_exclude) } fn process_wildcard_additional_options( diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 39a9eca5c8183..a4ae10997715e 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -390,23 +390,28 @@ impl SQLExprVisitor<'_> { } else { Schema::new() })) - }; + }?; - let mut column: PolarsResult = if lf.is_none() { + let mut column: PolarsResult = if lf.is_none() && schema.is_empty() { Ok(col(&ident_root.value)) } else { - let col_name = &remaining_idents.next().unwrap().value; - if let Some((_, name, _)) = schema?.get_full(col_name) { - let resolved = &self.ctx.resolve_name(&ident_root.value, col_name); + let name = &remaining_idents.next().unwrap().value; + if lf.is_some() && name == "*" { + Ok(cols(schema.iter_names())) + } else if let Some((_, name, _)) = schema.get_full(name) { + let resolved = &self.ctx.resolve_name(&ident_root.value, name); Ok(if name != resolved { col(resolved).alias(name) } else { col(name) }) + } else if lf.is_none() { + remaining_idents = idents.iter().skip(1); + Ok(col(&ident_root.value)) } else { polars_bail!( SQLInterface: "no column named '{}' found in table '{}'", - col_name, + name, ident_root ) } diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index 0a60b9dc7acab..262b4f6622bbd 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -530,6 +530,23 @@ fn test_ctes() -> PolarsResult<()> { Ok(()) } +#[test] +fn test_cte_values() -> PolarsResult<()> { + let sql = r#" + WITH + x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)), + y (m, n) AS ( + WITH z(c, d) AS (SELECT a, b FROM x) + SELECT d*2 AS d2, c*3 AS c3 FROM z + ) + SELECT n, m FROM y + "#; + let mut context = SQLContext::new(); + assert!(context.execute(sql).is_ok()); + + Ok(()) +} + #[test] #[cfg(feature = "ipc")] fn test_group_by_2() -> PolarsResult<()> { @@ -566,6 +583,7 @@ fn test_group_by_2() -> PolarsResult<()> { SortMultipleOptions::default().with_order_descending_multi([false, true]), ) .limit(2); + let expected = expected.collect()?; assert!(df_sql.equals(&expected)); Ok(()) @@ -591,6 +609,7 @@ fn test_case_expr() { .then(lit("lteq_5")) .otherwise(lit("no match")) .alias("sign"); + let df_pl = df.lazy().select(&[case_expr]).collect().unwrap(); assert!(df_sql.equals(&df_pl)); } @@ -600,6 +619,7 @@ fn test_case_expr_with_expression() { let df = create_sample_df().unwrap(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); + let sql = r#" SELECT CASE b%2 @@ -615,6 +635,7 @@ fn test_case_expr_with_expression() { .then(lit("odd")) .otherwise(lit("No?")) .alias("parity"); + let df_pl = df.lazy().select(&[case_expr]).collect().unwrap(); assert!(df_sql.equals(&df_pl)); } @@ -639,8 +660,38 @@ fn test_iss_9471() { } .unwrap() .lazy(); + let mut context = SQLContext::new(); context.register("df", df); let res = context.execute(sql); assert!(res.is_err()) } + +#[test] +fn test_struct_wildcards() { + let struct_cols = vec![col("num"), col("str"), col("val")]; + let df_original = df! { + "num" => [100, 200, 300, 400], + "str" => ["d", "c", "b", "a"], + "val" => [0.0, 5.0, 3.0, 4.0], + } + .unwrap(); + + let df_struct = df_original + .clone() + .lazy() + .select([as_struct(struct_cols).alias("json_msg")]); + + let mut context = SQLContext::new(); + context.register("df", df_struct.clone().lazy()); + + for sql in [ + r#"SELECT json_msg.* FROM df"#, + r#"SELECT df.json_msg.* FROM df"#, + r#"SELECT json_msg.* FROM df ORDER BY json_msg.num"#, + r#"SELECT df.json_msg.* FROM df ORDER BY json_msg.str DESC"#, + ] { + let df_sql = context.execute(sql).unwrap().collect().unwrap(); + assert!(df_sql.equals(&df_original)); + } +} diff --git a/py-polars/tests/unit/sql/test_structs.py b/py-polars/tests/unit/sql/test_structs.py index 6f1cad494ac6e..9ed6bd1a2cb03 100644 --- a/py-polars/tests/unit/sql/test_structs.py +++ b/py-polars/tests/unit/sql/test_structs.py @@ -8,7 +8,7 @@ @pytest.fixture() -def struct_df() -> pl.DataFrame: +def df_struct() -> pl.DataFrame: return pl.DataFrame( { "id": [100, 200, 300, 400], @@ -19,8 +19,8 @@ def struct_df() -> pl.DataFrame: ).select(pl.struct(pl.all()).alias("json_msg")) -def test_struct_field_selection(struct_df: pl.DataFrame) -> None: - res = struct_df.sql( +def test_struct_field_selection(df_struct: pl.DataFrame) -> None: + res = df_struct.sql( """ SELECT -- validate table alias resolution @@ -36,17 +36,39 @@ def test_struct_field_selection(struct_df: pl.DataFrame) -> None: json_msg.name DESC """ ) - expected = pl.DataFrame( - { - "ID": [400, 100], - "NAME": ["Zoe", "Alice"], - "AGE": [45, 32], - } + {"ID": [400, 100], "NAME": ["Zoe", "Alice"], "AGE": [45, 32]} ) assert_frame_equal(expected, res) +@pytest.mark.parametrize( + ("fields", "excluding"), + [ + ("json_msg.*", ""), + ("self.json_msg.*", ""), + ("json_msg.other.*", ""), + ("self.json_msg.other.*", ""), + ], +) +def test_struct_field_wildcard_selection( + fields: str, + excluding: str, + df_struct: pl.DataFrame, +) -> None: + query = f"SELECT {fields} {excluding} FROM df_struct ORDER BY json_msg.id" + print(query) + res = pl.sql(query).collect() + + expected = df_struct.unnest("json_msg") + if fields.endswith(".other.*"): + expected = expected["other"].struct.unnest() + if excluding: + expected = expected.drop(excluding.split(",")) + + assert_frame_equal(expected, res) + + @pytest.mark.parametrize( "invalid_column", [ @@ -55,6 +77,6 @@ def test_struct_field_selection(struct_df: pl.DataFrame) -> None: "self.json_msg.other.invalid_column", ], ) -def test_struct_indexing_errors(invalid_column: str, struct_df: pl.DataFrame) -> None: +def test_struct_indexing_errors(invalid_column: str, df_struct: pl.DataFrame) -> None: with pytest.raises(StructFieldNotFoundError, match="invalid_column"): - struct_df.sql(f"SELECT {invalid_column} FROM self") + df_struct.sql(f"SELECT {invalid_column} FROM self")