Skip to content

Commit

Permalink
add struct.* wildcard selection support
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Jun 21, 2024
1 parent 3bb8050 commit 239ddde
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 43 deletions.
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/conversion/expr_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ fn find_flags(expr: &Expr) -> PolarsResult<ExpansionFlags> {

/// In case of single col(*) -> do nothing, no selection is the same as select all
/// In other cases replace the wildcard with an expression with all columns
pub(crate) fn rewrite_projections(
pub fn rewrite_projections(
exprs: Vec<Expr>,
schema: &Schema,
keys: &[Expr],
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/conversion/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod convert_utils;
mod dsl_to_ir;
mod expr_expansion;
pub(crate) mod expr_expansion;
mod expr_to_ir;
mod ir_to_dsl;
#[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))]
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub(crate) mod ir;
mod apply;
mod builder_dsl;
mod builder_ir;
pub(crate) mod conversion;
pub mod conversion;
#[cfg(feature = "debugging")]
pub(crate) mod debug;
pub mod expr_ir;
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub(crate) use polars_time::prelude::*;
pub use polars_utils::arena::{Arena, Node};

pub use crate::dsl::*;
pub use crate::plans::conversion::expr_expansion::rewrite_projections;
#[cfg(feature = "debugging")]
pub use crate::plans::debug::*;
pub use crate::plans::options::*;
Expand Down
38 changes: 15 additions & 23 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use polars_lazy::prelude::*;
use polars_ops::frame::JoinCoalesce;
use polars_plan::prelude::*;
use sqlparser::ast::{
Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinConstraint,
Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident, JoinConstraint,
JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr,
SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator,
Value as SQLValue, Values, WildcardAdditionalOptions,
Expand Down Expand Up @@ -600,7 +600,7 @@ impl SQLContext {
lf = self.process_where(lf, &select_stmt.selection)?;

// Column projections.
let projections: Vec<_> = select_stmt
let mut projections: Vec<Expr> = select_stmt
.projection
.iter()
.map(|select_item| {
Expand All @@ -610,11 +610,12 @@ impl SQLContext {
let expr = parse_sql_expr(expr, self, schema.as_deref())?;
expr.alias(&alias.value)
},
SelectItem::QualifiedWildcard(oname, wildcard_options) => self
SelectItem::QualifiedWildcard(obj_name, wildcard_options) => self
.process_qualified_wildcard(
oname,
obj_name,
wildcard_options,
&mut contains_wildcard_exclude,
schema.as_deref(),
)?,
SelectItem::Wildcard(wildcard_options) => {
contains_wildcard = true;
Expand All @@ -629,7 +630,10 @@ impl SQLContext {
})
.collect::<PolarsResult<_>>()?;

// Check for "GROUP BY ..." (after projections, as there may be ordinal/position ints).
// expand/flatten projections, so we have distinct expressions/columns
projections = rewrite_projections(projections, &(schema.clone().unwrap()), &[])?;

// Check for "GROUP BY ..." (after determining projections)
let mut group_by_keys: Vec<Expr> = Vec::new();
match &select_stmt.group_by {
// Standard "GROUP BY x, y, z" syntax (also recognising ordinal values)
Expand Down Expand Up @@ -1152,25 +1156,13 @@ impl SQLContext {
ObjectName(idents): &ObjectName,
options: &WildcardAdditionalOptions,
contains_wildcard_exclude: &mut bool,
schema: Option<&Schema>,
) -> PolarsResult<Expr> {
let idents = idents.as_slice();
let e = match idents {
[tbl_name] => {
let lf = self.table_map.get_mut(&tbl_name.value).ok_or_else(|| {
polars_err!(
SQLInterface: "no table named '{}' found",
tbl_name
)
})?;
let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?;
cols(schema.iter_names())
},
e => polars_bail!(
SQLSyntax: "invalid wildcard expression ({:?})",
e
),
};
self.process_wildcard_additional_options(e, options, contains_wildcard_exclude)
let mut new_idents = idents.clone();
new_idents.push(Ident::new("*"));
let identifier = SQLExpr::CompoundIdentifier(new_idents);
let expr = parse_sql_expr(&identifier, self, schema)?;
self.process_wildcard_additional_options(expr, options, contains_wildcard_exclude)
}

fn process_wildcard_additional_options(
Expand Down
17 changes: 11 additions & 6 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,23 +390,28 @@ impl SQLExprVisitor<'_> {
} else {
Schema::new()
}))
};
}?;

let mut column: PolarsResult<Expr> = if lf.is_none() {
let mut column: PolarsResult<Expr> = if lf.is_none() && schema.is_empty() {
Ok(col(&ident_root.value))
} else {
let col_name = &remaining_idents.next().unwrap().value;
if let Some((_, name, _)) = schema?.get_full(col_name) {
let resolved = &self.ctx.resolve_name(&ident_root.value, col_name);
let name = &remaining_idents.next().unwrap().value;
if lf.is_some() && name == "*" {
Ok(cols(schema.iter_names()))
} else if let Some((_, name, _)) = schema.get_full(name) {
let resolved = &self.ctx.resolve_name(&ident_root.value, name);
Ok(if name != resolved {
col(resolved).alias(name)
} else {
col(name)
})
} else if lf.is_none() {
remaining_idents = idents.iter().skip(1);
Ok(col(&ident_root.value))
} else {
polars_bail!(
SQLInterface: "no column named '{}' found in table '{}'",
col_name,
name,
ident_root
)
}
Expand Down
51 changes: 51 additions & 0 deletions crates/polars-sql/tests/simple_exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,23 @@ fn test_ctes() -> PolarsResult<()> {
Ok(())
}

#[test]
fn test_cte_values() -> PolarsResult<()> {
    // Nested CTEs: an inline VALUES table aliased with column names, consumed
    // by a second CTE that itself declares positional aliases. The query should
    // parse and execute without error.
    let mut ctx = SQLContext::new();
    let query = r#"
    WITH
      x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)),
      y (m, n) AS (
        WITH z(c, d) AS (SELECT a, b FROM x)
        SELECT d*2 AS d2, c*3 AS c3 FROM z
      )
    SELECT n, m FROM y
    "#;
    assert!(ctx.execute(query).is_ok());

    Ok(())
}

#[test]
#[cfg(feature = "ipc")]
fn test_group_by_2() -> PolarsResult<()> {
Expand Down Expand Up @@ -566,6 +583,7 @@ fn test_group_by_2() -> PolarsResult<()> {
SortMultipleOptions::default().with_order_descending_multi([false, true]),
)
.limit(2);

let expected = expected.collect()?;
assert!(df_sql.equals(&expected));
Ok(())
Expand All @@ -591,6 +609,7 @@ fn test_case_expr() {
.then(lit("lteq_5"))
.otherwise(lit("no match"))
.alias("sign");

let df_pl = df.lazy().select(&[case_expr]).collect().unwrap();
assert!(df_sql.equals(&df_pl));
}
Expand All @@ -600,6 +619,7 @@ fn test_case_expr_with_expression() {
let df = create_sample_df().unwrap();
let mut context = SQLContext::new();
context.register("df", df.clone().lazy());

let sql = r#"
SELECT
CASE b%2
Expand All @@ -615,6 +635,7 @@ fn test_case_expr_with_expression() {
.then(lit("odd"))
.otherwise(lit("No?"))
.alias("parity");

let df_pl = df.lazy().select(&[case_expr]).collect().unwrap();
assert!(df_sql.equals(&df_pl));
}
Expand All @@ -639,8 +660,38 @@ fn test_iss_9471() {
}
.unwrap()
.lazy();

let mut context = SQLContext::new();
context.register("df", df);
let res = context.execute(sql);
assert!(res.is_err())
}

#[test]
fn test_struct_wildcards() {
    // Pack three plain columns into a single struct column ("json_msg") and
    // confirm that `struct_col.*` (optionally table-qualified, optionally with
    // an ORDER BY on a struct field) expands back to the original columns.
    let fields = vec![col("num"), col("str"), col("val")];
    let expected = df! {
        "num" => [100, 200, 300, 400],
        "str" => ["d", "c", "b", "a"],
        "val" => [0.0, 5.0, 3.0, 4.0],
    }
    .unwrap();

    let nested = expected
        .clone()
        .lazy()
        .select([as_struct(fields).alias("json_msg")]);

    let mut ctx = SQLContext::new();
    ctx.register("df", nested.clone().lazy());

    let queries = [
        r#"SELECT json_msg.* FROM df"#,
        r#"SELECT df.json_msg.* FROM df"#,
        r#"SELECT json_msg.* FROM df ORDER BY json_msg.num"#,
        r#"SELECT df.json_msg.* FROM df ORDER BY json_msg.str DESC"#,
    ];
    for query in queries {
        let actual = ctx.execute(query).unwrap().collect().unwrap();
        assert!(actual.equals(&expected));
    }
}
44 changes: 33 additions & 11 deletions py-polars/tests/unit/sql/test_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@pytest.fixture()
def struct_df() -> pl.DataFrame:
def df_struct() -> pl.DataFrame:
return pl.DataFrame(
{
"id": [100, 200, 300, 400],
Expand All @@ -19,8 +19,8 @@ def struct_df() -> pl.DataFrame:
).select(pl.struct(pl.all()).alias("json_msg"))


def test_struct_field_selection(struct_df: pl.DataFrame) -> None:
res = struct_df.sql(
def test_struct_field_selection(df_struct: pl.DataFrame) -> None:
res = df_struct.sql(
"""
SELECT
-- validate table alias resolution
Expand All @@ -36,17 +36,39 @@ def test_struct_field_selection(struct_df: pl.DataFrame) -> None:
json_msg.name DESC
"""
)

expected = pl.DataFrame(
{
"ID": [400, 100],
"NAME": ["Zoe", "Alice"],
"AGE": [45, 32],
}
{"ID": [400, 100], "NAME": ["Zoe", "Alice"], "AGE": [45, 32]}
)
assert_frame_equal(expected, res)


@pytest.mark.parametrize(
    ("fields", "excluding"),
    [
        ("json_msg.*", ""),
        ("self.json_msg.*", ""),
        ("json_msg.other.*", ""),
        ("self.json_msg.other.*", ""),
    ],
)
def test_struct_field_wildcard_selection(
    fields: str,
    excluding: str,
    df_struct: pl.DataFrame,
) -> None:
    """Validate wildcard expansion of struct fields (`struct_col.*`).

    Covers table-qualified (`self.`) and nested (`.other.*`) forms; the
    result should match unnesting the struct column directly.
    """
    query = f"SELECT {fields} {excluding} FROM df_struct ORDER BY json_msg.id"
    res = pl.sql(query).collect()

    expected = df_struct.unnest("json_msg")
    if fields.endswith(".other.*"):
        # nested wildcard selects only the inner struct's fields
        expected = expected["other"].struct.unnest()
    if excluding:
        expected = expected.drop(excluding.split(","))

    assert_frame_equal(expected, res)


@pytest.mark.parametrize(
"invalid_column",
[
Expand All @@ -55,6 +77,6 @@ def test_struct_field_selection(struct_df: pl.DataFrame) -> None:
"self.json_msg.other.invalid_column",
],
)
def test_struct_indexing_errors(invalid_column: str, struct_df: pl.DataFrame) -> None:
def test_struct_indexing_errors(invalid_column: str, df_struct: pl.DataFrame) -> None:
with pytest.raises(StructFieldNotFoundError, match="invalid_column"):
struct_df.sql(f"SELECT {invalid_column} FROM self")
df_struct.sql(f"SELECT {invalid_column} FROM self")

0 comments on commit 239ddde

Please sign in to comment.