Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Scalar checks #18627

Merged
merged 11 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions crates/polars-expr/src/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,10 @@ impl PhysicalExpr for AggregationExpr {
}
}

/// Unconditionally reports this expression as scalar (always `true`).
fn is_scalar(&self) -> bool {
true
}

/// Exposes this expression as a `PartitionedAggregation` trait object.
///
/// Always returns `Some(self)`.
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}
Expand Down Expand Up @@ -742,6 +746,10 @@ impl PhysicalExpr for AggQuantileExpr {
/// Delegates field resolution to `self.input`, passing `input_schema`
/// through unchanged and returning the input's result (or error) as-is.
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.input.to_field(input_schema)
}

/// Always `true`: this expression is reported as scalar regardless of its input.
fn is_scalar(&self) -> bool {
true
}
}

/// Simple wrapper to parallelize functions that can be divided over threads aggregated and
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/alias.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ impl PhysicalExpr for AliasExpr {
))
}

/// Forwards to the wrapped `physical_expr`, so this expression is scalar
/// exactly when the wrapped expression reports itself as scalar.
fn is_scalar(&self) -> bool {
self.physical_expr.is_scalar()
}

/// Returns `Some(self)`, making this expression available as a
/// `PartitionedAggregation` trait object.
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}
Expand Down
18 changes: 13 additions & 5 deletions crates/polars-expr/src/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ pub struct ApplyExpr {
function: SpecialEq<Arc<dyn SeriesUdf>>,
expr: Expr,
collect_groups: ApplyOptions,
returns_scalar: bool,
function_returns_scalar: bool,
function_operates_on_scalar: bool,
allow_rename: bool,
pass_name_to_apply: bool,
input_schema: Option<SchemaRef>,
Expand All @@ -29,6 +30,7 @@ pub struct ApplyExpr {
}

impl ApplyExpr {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
inputs: Vec<Arc<dyn PhysicalExpr>>,
function: SpecialEq<Arc<dyn SeriesUdf>>,
Expand All @@ -37,6 +39,7 @@ impl ApplyExpr {
allow_threading: bool,
input_schema: Option<SchemaRef>,
output_dtype: Option<DataType>,
returns_scalar: bool,
) -> Self {
#[cfg(debug_assertions)]
if matches!(options.collect_groups, ApplyOptions::ElementWise)
Expand All @@ -50,7 +53,8 @@ impl ApplyExpr {
function,
expr,
collect_groups: options.collect_groups,
returns_scalar: options.flags.contains(FunctionFlags::RETURNS_SCALAR),
function_returns_scalar: options.flags.contains(FunctionFlags::RETURNS_SCALAR),
function_operates_on_scalar: returns_scalar,
allow_rename: options.flags.contains(FunctionFlags::ALLOW_RENAME),
pass_name_to_apply: options.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY),
input_schema,
Expand All @@ -72,7 +76,8 @@ impl ApplyExpr {
function,
expr,
collect_groups,
returns_scalar: false,
function_returns_scalar: false,
function_operates_on_scalar: false,
allow_rename: false,
pass_name_to_apply: false,
input_schema: None,
Expand Down Expand Up @@ -104,7 +109,7 @@ impl ApplyExpr {
ca: ListChunked,
) -> PolarsResult<AggregationContext<'a>> {
let all_unit_len = all_unit_length(&ca);
if all_unit_len && self.returns_scalar {
if all_unit_len && self.function_returns_scalar {
ac.with_agg_state(AggState::AggregatedScalar(
ca.explode().unwrap().into_series(),
));
Expand Down Expand Up @@ -253,7 +258,7 @@ impl ApplyExpr {
let mut ac = acs.swap_remove(0);
ac.with_update_groups(UpdateGroups::No);

let agg_state = if self.returns_scalar {
let agg_state = if self.function_returns_scalar {
AggState::AggregatedScalar(Series::new_empty(field.name().clone(), &field.dtype))
} else {
match self.collect_groups {
Expand Down Expand Up @@ -426,6 +431,9 @@ impl PhysicalExpr for ApplyExpr {
None
}
}
/// Scalar when at least one of the two stored flags is set:
/// `function_returns_scalar` or `function_operates_on_scalar`.
// NOTE(review): the two flags appear to come from different sources (the
// function's `RETURNS_SCALAR` flag vs. a value passed to the constructor) —
// confirm against `ApplyExpr::new`.
fn is_scalar(&self) -> bool {
self.function_returns_scalar || self.function_operates_on_scalar
}
}

fn apply_multiple_elementwise<'a>(
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct BinaryExpr {
expr: Expr,
has_literal: bool,
allow_threading: bool,
is_scalar: bool,
}

impl BinaryExpr {
Expand All @@ -25,6 +26,7 @@ impl BinaryExpr {
expr: Expr,
has_literal: bool,
allow_threading: bool,
is_scalar: bool,
) -> Self {
Self {
left,
Expand All @@ -33,6 +35,7 @@ impl BinaryExpr {
expr,
has_literal,
allow_threading,
is_scalar,
}
}
}
Expand Down Expand Up @@ -254,6 +257,10 @@ impl PhysicalExpr for BinaryExpr {
self.expr.to_field(input_schema, Context::Default)
}

/// Returns the `is_scalar` flag stored on this expression.
///
/// The value is not computed here; it is whatever was stored in the field.
fn is_scalar(&self) -> bool {
self.is_scalar
}

/// Always returns `Some(self)` as a `PartitionedAggregation` trait object.
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ impl PhysicalExpr for CastExpr {
})
}

/// Mirrors the scalar-ness of `self.input`: the result is whatever the
/// input expression's own `is_scalar` returns.
fn is_scalar(&self) -> bool {
self.input.is_scalar()
}

/// Hands back `self` as a `PartitionedAggregation` trait object; never `None`.
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-expr/src/expressions/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ impl PhysicalExpr for ColumnExpr {
)
})
}
/// Always `false`: this expression is never reported as scalar.
fn is_scalar(&self) -> bool {
false
}
}

impl PartitionedAggregation for ColumnExpr {
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ impl PhysicalExpr for CountExpr {
/// Returns `Some(self)` so callers can use this expression through the
/// `PartitionedAggregation` interface.
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}

/// Always `true`: this expression is unconditionally reported as scalar.
fn is_scalar(&self) -> bool {
true
}
}

impl PartitionedAggregation for CountExpr {
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,8 @@ impl PhysicalExpr for FilterExpr {
/// Returns the field reported by `self.input` for the given `input_schema`;
/// this expression adds nothing of its own to the result.
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.input.to_field(input_schema)
}

/// Always `false`, independent of whether `self.input` is scalar.
fn is_scalar(&self) -> bool {
false
}
}
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ impl PhysicalExpr for GatherExpr {
/// Delegates to `self.phys_expr.to_field`, forwarding `input_schema` unchanged.
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.phys_expr.to_field(input_schema)
}

/// Returns the stored `returns_scalar` flag of this expression.
fn is_scalar(&self) -> bool {
self.returns_scalar
}
}

impl GatherExpr {
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ impl PhysicalExpr for LiteralExpr {
/// Always `true` for this implementation.
fn is_literal(&self) -> bool {
true
}

/// Delegates to `is_scalar` on the wrapped value held in tuple field `0`.
fn is_scalar(&self) -> bool {
self.0.is_scalar()
}
}

impl PartitionedAggregation for LiteralExpr {
Expand Down
1 change: 1 addition & 0 deletions crates/polars-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ pub trait PhysicalExpr: Send + Sync {
/// Whether this expression is a literal.
///
/// Provided method: the default answer is `false`; implementors may override it.
fn is_literal(&self) -> bool {
false
}
/// Whether this expression reports itself as scalar.
///
/// Required method: no default body is provided, so every implementor of
/// the trait must define it explicitly.
// NOTE(review): presumably consumed when deciding whether a unit-length
// result may be broadcast to the frame height — confirm against the callers.
fn is_scalar(&self) -> bool;
}

impl Display for &dyn PhysicalExpr {
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/rolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,8 @@ impl PhysicalExpr for RollingExpr {
/// Returns a reference to the stored `expr`; always `Some`.
fn as_expression(&self) -> Option<&Expr> {
Some(&self.expr)
}

/// Always `false`: this expression is never reported as scalar.
fn is_scalar(&self) -> bool {
false
}
}
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,4 +266,8 @@ impl PhysicalExpr for SliceExpr {
/// Returns whatever `self.input.to_field` yields for `input_schema`,
/// including any error.
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.input.to_field(input_schema)
}

/// Always `false`; the scalar-ness of `self.input` is not consulted.
fn is_scalar(&self) -> bool {
false
}
}
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,8 @@ impl PhysicalExpr for SortExpr {
/// Delegates to the wrapped `physical_expr`, forwarding `input_schema` unchanged.
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.physical_expr.to_field(input_schema)
}

/// Always `false`; the wrapped `physical_expr` is not consulted.
fn is_scalar(&self) -> bool {
false
}
}
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,4 +357,8 @@ impl PhysicalExpr for SortByExpr {
/// Returns the field reported by `self.input` for `input_schema`.
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.input.to_field(input_schema)
}

/// Always `false`: this expression is never reported as scalar.
fn is_scalar(&self) -> bool {
false
}
}
7 changes: 7 additions & 0 deletions crates/polars-expr/src/expressions/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub struct TernaryExpr {
expr: Expr,
// Can be expensive on small data to run literals in parallel.
run_par: bool,
returns_scalar: bool,
}

impl TernaryExpr {
Expand All @@ -21,13 +22,15 @@ impl TernaryExpr {
falsy: Arc<dyn PhysicalExpr>,
expr: Expr,
run_par: bool,
returns_scalar: bool,
) -> Self {
Self {
predicate,
truthy,
falsy,
expr,
run_par,
returns_scalar,
}
}
}
Expand Down Expand Up @@ -322,6 +325,10 @@ impl PhysicalExpr for TernaryExpr {
/// Always returns `Some(self)` as a `PartitionedAggregation` trait object.
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}

/// Returns the `returns_scalar` flag stored on this expression.
fn is_scalar(&self) -> bool {
self.returns_scalar
}
}

impl PartitionedAggregation for TernaryExpr {
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-expr/src/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,10 @@ impl PhysicalExpr for WindowExpr {
match self.determine_map_strategy(ac.agg_state(), sorted_keys, &gb)? {
Nothing => {
let mut out = ac.flat_naive().into_owned();

if ac.is_literal() {
out = out.new_from_index(0, df.height())
}
cache_gb(gb, state, &cache_key);
if let Some(name) = &self.out_name {
out.rename(name.clone());
Expand Down Expand Up @@ -630,6 +634,10 @@ impl PhysicalExpr for WindowExpr {
self.function.to_field(input_schema, Context::Default)
}

/// Always `false`: this expression is never reported as scalar.
fn is_scalar(&self) -> bool {
false
}

#[allow(clippy::ptr_arg)]
fn evaluate_on_groups<'a>(
&self,
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ fn create_physical_expr_inner(
)))
},
BinaryExpr { left, op, right } => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let lhs = create_physical_expr_inner(*left, ctxt, expr_arena, schema, state)?;
let rhs = create_physical_expr_inner(*right, ctxt, expr_arena, schema, state)?;
Ok(Arc::new(phys_expr::BinaryExpr::new(
Expand All @@ -302,6 +303,7 @@ fn create_physical_expr_inner(
node_to_expr(expression, expr_arena),
state.local.has_lit,
state.allow_threading,
is_scalar,
)))
},
Column(column) => Ok(Arc::new(ColumnExpr::new(
Expand Down Expand Up @@ -444,6 +446,7 @@ fn create_physical_expr_inner(
truthy,
falsy,
} => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let mut lit_count = 0u8;
state.reset();
let predicate =
Expand All @@ -461,6 +464,7 @@ fn create_physical_expr_inner(
falsy,
node_to_expr(expression, expr_arena),
lit_count < 2,
is_scalar,
)))
},
AnonymousFunction {
Expand All @@ -469,6 +473,7 @@ fn create_physical_expr_inner(
output_type: _,
options,
} => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let output_dtype = schema.and_then(|schema| {
expr_arena
.get(expression)
Expand Down Expand Up @@ -500,6 +505,7 @@ fn create_physical_expr_inner(
state.allow_threading,
schema.cloned(),
output_dtype,
is_scalar,
)))
},
Function {
Expand All @@ -508,6 +514,7 @@ fn create_physical_expr_inner(
options,
..
} => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let output_dtype = schema.and_then(|schema| {
expr_arena
.get(expression)
Expand Down Expand Up @@ -538,6 +545,7 @@ fn create_physical_expr_inner(
state.allow_threading,
schema.cloned(),
output_dtype,
is_scalar,
)))
},
Slice {
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-mem-engine/src/executors/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl ProjectionExec {
self.has_windows,
self.options.run_parallel,
)?;
check_expand_literals(selected_cols, df.is_empty(), self.options)
check_expand_literals(&df, &self.expr, selected_cols, df.is_empty(), self.options)
});

let df = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;
Expand All @@ -53,7 +53,7 @@ impl ProjectionExec {
self.has_windows,
self.options.run_parallel,
)?;
check_expand_literals(selected_cols, df.is_empty(), self.options)?
check_expand_literals(&df, &self.expr, selected_cols, df.is_empty(), self.options)?
};

// This only runs during testing and checks that the runtime type matches the predicted schema.
Expand Down
Loading