Skip to content

Commit

Permalink
refactor(rust): Recursively evaluate is_elementwise for function expressions (#18385)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Aug 27, 2024
1 parent 6f5851d commit 25b159d
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 62 deletions.
68 changes: 37 additions & 31 deletions crates/polars-stream/src/physical_plan/lower_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,36 +25,49 @@ fn unique_column_name() -> ColumnName {
format!("__POLARS_STMP_{idx}").into()
}

/// Memoization tables for per-expression properties.
///
/// Both tables are keyed by the expression's arena node and store the answer
/// already computed for that node. The cache is handed in by the caller of
/// `lower_exprs` / `build_select_node`, so results can outlive a single call.
pub(crate) struct ExprCache {
    /// Results of `is_elementwise`, keyed by expression node.
    is_elementwise: PlHashMap<Node, bool>,
    /// Memo table passed to `is_input_independent_rec`, keyed by expression node.
    is_input_independent: PlHashMap<Node, bool>,
}

impl ExprCache {
    /// Creates an empty cache in which each of the two tables is pre-sized
    /// for `capacity` entries.
    pub fn with_capacity(capacity: usize) -> Self {
        let is_elementwise = PlHashMap::with_capacity(capacity);
        let is_input_independent = PlHashMap::with_capacity(capacity);
        Self {
            is_elementwise,
            is_input_independent,
        }
    }
}

// NOTE(review): this listing looks like a rendered diff that shows removed and added
// lines together. The two owned maps (`is_elementwise_cache`,
// `is_input_independent_cache`) appear to be superseded by the borrowed `cache`
// field — confirm against the real file before relying on them.
/// Mutable state threaded through expression lowering: the expression arena,
/// the physical-node slot map, and the memoization tables.
struct LowerExprContext<'a> {
// Arena holding the `AExpr` nodes that lowering inspects.
expr_arena: &'a mut Arena<AExpr>,
// Slot map of physical plan nodes, keyed by `PhysNodeKey`.
phys_sm: &'a mut SlotMap<PhysNodeKey, PhysNode>,
// Memo table owned by this context and passed to `is_elementwise_rec`.
is_elementwise_cache: PlHashMap<Node, bool>,
// Memo table owned by this context for the input-independence check.
is_input_independent_cache: PlHashMap<Node, bool>,
// Caller-provided cache bundling both memo tables.
cache: &'a mut ExprCache,
}

#[recursive::recursive]
fn is_elementwise_rec(
pub(crate) fn is_elementwise(
expr_key: IRNodeKey,
arena: &Arena<AExpr>,
cache: &mut PlHashMap<IRNodeKey, bool>,
cache: &mut ExprCache,
) -> bool {
if let Some(ret) = cache.get(&expr_key) {
if let Some(ret) = cache.is_elementwise.get(&expr_key) {
return *ret;
}

let ret = match arena.get(expr_key) {
AExpr::Explode(_) => false,
AExpr::Alias(inner, _) => is_elementwise_rec(*inner, arena, cache),
AExpr::Alias(inner, _) => is_elementwise(*inner, arena, cache),
AExpr::Column(_) => true,
AExpr::Literal(lit) => !matches!(lit, LiteralValue::Series(_) | LiteralValue::Range { .. }),
AExpr::BinaryExpr { left, op: _, right } => {
is_elementwise_rec(*left, arena, cache) && is_elementwise_rec(*right, arena, cache)
is_elementwise(*left, arena, cache) && is_elementwise(*right, arena, cache)
},
AExpr::Cast {
expr,
data_type: _,
options: _,
} => is_elementwise_rec(*expr, arena, cache),
} => is_elementwise(*expr, arena, cache),
AExpr::Sort { .. } | AExpr::SortBy { .. } | AExpr::Gather { .. } => false,
AExpr::Filter { .. } => false,
AExpr::Agg(_) => false,
Expand All @@ -63,40 +76,33 @@ fn is_elementwise_rec(
truthy,
falsy,
} => {
is_elementwise_rec(*predicate, arena, cache)
&& is_elementwise_rec(*truthy, arena, cache)
&& is_elementwise_rec(*falsy, arena, cache)
is_elementwise(*predicate, arena, cache)
&& is_elementwise(*truthy, arena, cache)
&& is_elementwise(*falsy, arena, cache)
},
AExpr::AnonymousFunction {
input: _,
input,
function: _,
output_type: _,
options,
} => options.is_elementwise(),
AExpr::Function {
}
| AExpr::Function {
input,
function,
function: _,
options,
} => match function {
FunctionExpr::AsStruct => input
.iter()
.all(|expr| is_elementwise_rec(expr.node(), arena, cache)),
_ => options.is_elementwise(),
} => {
options.is_elementwise() && input.iter().all(|e| is_elementwise(e.node(), arena, cache))
},

AExpr::Window { .. } => false,
AExpr::Slice { .. } => false,
AExpr::Len => false,
};

cache.insert(expr_key, ret);
cache.is_elementwise.insert(expr_key, ret);
ret
}

// NOTE(review): the call in `lower_exprs_with_ctx` shown in this listing uses a
// three-argument `is_elementwise(expr, arena, cache)`; this two-argument wrapper
// looks like the pre-refactor form — confirm whether it still exists in the file.
/// Context-based convenience wrapper: evaluates `is_elementwise_rec` for
/// `expr_key` using the arena and the elementwise memo table stored in `ctx`.
fn is_elementwise(expr_key: IRNodeKey, ctx: &mut LowerExprContext) -> bool {
is_elementwise_rec(expr_key, ctx.expr_arena, &mut ctx.is_elementwise_cache)
}

#[recursive::recursive]
fn is_input_independent_rec(
expr_key: IRNodeKey,
Expand Down Expand Up @@ -212,7 +218,7 @@ fn is_input_independent(expr_key: IRNodeKey, ctx: &mut LowerExprContext) -> bool
is_input_independent_rec(
expr_key,
ctx.expr_arena,
&mut ctx.is_input_independent_cache,
&mut ctx.cache.is_input_independent,
)
}

Expand Down Expand Up @@ -389,7 +395,7 @@ fn lower_exprs_with_ctx(
let mut transformed_exprs = Vec::with_capacity(exprs.len());

for expr in exprs.iter().copied() {
if is_elementwise(expr, ctx) {
if is_elementwise(expr, ctx.expr_arena, ctx.cache) {
if !is_input_independent(expr, ctx) {
input_nodes.insert(input);
}
Expand Down Expand Up @@ -716,12 +722,12 @@ pub fn lower_exprs(
exprs: &[ExprIR],
expr_arena: &mut Arena<AExpr>,
phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
expr_cache: &mut ExprCache,
) -> PolarsResult<(PhysNodeKey, Vec<ExprIR>)> {
let mut ctx = LowerExprContext {
expr_arena,
phys_sm,
is_elementwise_cache: PlHashMap::new(),
is_input_independent_cache: PlHashMap::new(),
cache: expr_cache,
};
let node_exprs = exprs.iter().map(|e| e.node()).collect_vec();
let (transformed_input, transformed_exprs) =
Expand All @@ -740,12 +746,12 @@ pub fn build_select_node(
exprs: &[ExprIR],
expr_arena: &mut Arena<AExpr>,
phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
expr_cache: &mut ExprCache,
) -> PolarsResult<PhysNodeKey> {
let mut ctx = LowerExprContext {
expr_arena,
phys_sm,
is_elementwise_cache: PlHashMap::new(),
is_input_independent_cache: PlHashMap::new(),
cache: expr_cache,
};
build_select_node_with_ctx(input, exprs, &mut ctx)
}
Loading

0 comments on commit 25b159d

Please sign in to comment.