Skip to content

Commit 0f88a2d

Browse files
alamb and comphead authored
Avoid cloning in log::simplify and power::simplify (#10086)
* Avoid cloning in power::simplify
* Avoid cloning in log::simplify
* Apply suggestions from code review

Co-authored-by: comphead <comphead@users.noreply.github.com>

---------

Co-authored-by: comphead <comphead@users.noreply.github.com>
1 parent da40cb9 commit 0f88a2d

File tree

2 files changed

+92
-55
lines changed

2 files changed

+92
-55
lines changed

datafusion/functions/src/math/log.rs

Lines changed: 58 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@
1818
//! Math function: `log()`.
1919
2020
use arrow::datatypes::DataType;
21-
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
21+
use datafusion_common::{
22+
exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result,
23+
ScalarValue,
24+
};
2225
use datafusion_expr::expr::ScalarFunction;
2326
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
24-
use datafusion_expr::{ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition};
27+
use datafusion_expr::{
28+
lit, ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition,
29+
};
2530

2631
use arrow::array::{ArrayRef, Float32Array, Float64Array};
2732
use datafusion_expr::TypeSignature::*;
@@ -146,51 +151,74 @@ impl ScalarUDFImpl for LogFunc {
146151
/// 3. Log(a, a) ===> 1
147152
fn simplify(
148153
&self,
149-
args: Vec<Expr>,
154+
mut args: Vec<Expr>,
150155
info: &dyn SimplifyInfo,
151156
) -> Result<ExprSimplifyResult> {
152-
let mut number = &args[0];
153-
let mut base =
154-
&Expr::Literal(ScalarValue::new_ten(&info.get_data_type(number)?)?);
155-
if args.len() == 2 {
156-
base = &args[0];
157-
number = &args[1];
157+
// Args are either
158+
// log(number)
159+
// log(base, number)
160+
let num_args = args.len();
161+
if num_args > 2 {
162+
return plan_err!("Expected log to have 1 or 2 arguments, got {num_args}");
158163
}
164+
let number = args.pop().ok_or_else(|| {
165+
plan_datafusion_err!("Expected log to have 1 or 2 arguments, got 0")
166+
})?;
167+
let number_datatype = info.get_data_type(&number)?;
168+
// default to base 10
169+
let base = if let Some(base) = args.pop() {
170+
base
171+
} else {
172+
lit(ScalarValue::new_ten(&number_datatype)?)
173+
};
159174

160175
match number {
161-
Expr::Literal(value)
162-
if value == &ScalarValue::new_one(&info.get_data_type(number)?)? =>
163-
{
164-
Ok(ExprSimplifyResult::Simplified(Expr::Literal(
165-
ScalarValue::new_zero(&info.get_data_type(base)?)?,
166-
)))
176+
Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype)? => {
177+
Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero(
178+
&info.get_data_type(&base)?,
179+
)?)))
167180
}
168-
Expr::ScalarFunction(ScalarFunction {
169-
func_def: ScalarFunctionDefinition::UDF(fun),
170-
args,
171-
}) if base == &args[0]
172-
&& fun
173-
.as_ref()
174-
.inner()
175-
.as_any()
176-
.downcast_ref::<PowerFunc>()
177-
.is_some() =>
181+
Expr::ScalarFunction(ScalarFunction { func_def, mut args })
182+
if is_pow(&func_def) && args.len() == 2 && base == args[0] =>
178183
{
179-
Ok(ExprSimplifyResult::Simplified(args[1].clone()))
184+
let b = args.pop().unwrap(); // length checked above
185+
Ok(ExprSimplifyResult::Simplified(b))
180186
}
181-
_ => {
187+
number => {
182188
if number == base {
183-
Ok(ExprSimplifyResult::Simplified(Expr::Literal(
184-
ScalarValue::new_one(&info.get_data_type(number)?)?,
185-
)))
189+
Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
190+
&number_datatype,
191+
)?)))
186192
} else {
193+
let args = match num_args {
194+
1 => vec![number],
195+
2 => vec![number, base],
196+
_ => {
197+
return internal_err!(
198+
"Unexpected number of arguments in log::simplify"
199+
)
200+
}
201+
};
187202
Ok(ExprSimplifyResult::Original(args))
188203
}
189204
}
190205
}
191206
}
192207
}
193208

209+
/// Returns true if the function is `PowerFunc`
210+
fn is_pow(func_def: &ScalarFunctionDefinition) -> bool {
211+
if let ScalarFunctionDefinition::UDF(fun) = func_def {
212+
fun.as_ref()
213+
.inner()
214+
.as_any()
215+
.downcast_ref::<PowerFunc>()
216+
.is_some()
217+
} else {
218+
false
219+
}
220+
}
221+
194222
#[cfg(test)]
195223
mod tests {
196224
use datafusion_common::cast::{as_float32_array, as_float64_array};

datafusion/functions/src/math/power.rs

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
//! Math function: `power()`.
1919
2020
use arrow::datatypes::DataType;
21-
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
21+
use datafusion_common::{
22+
exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue,
23+
};
2224
use datafusion_expr::expr::ScalarFunction;
2325
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
2426
use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionDefinition};
@@ -118,43 +120,50 @@ impl ScalarUDFImpl for PowerFunc {
118120
/// 3. Power(a, Log(a, b)) ===> b
119121
fn simplify(
120122
&self,
121-
args: Vec<Expr>,
123+
mut args: Vec<Expr>,
122124
info: &dyn SimplifyInfo,
123125
) -> Result<ExprSimplifyResult> {
124-
let base = &args[0];
125-
let exponent = &args[1];
126-
126+
let exponent = args.pop().ok_or_else(|| {
127+
plan_datafusion_err!("Expected power to have 2 arguments, got 0")
128+
})?;
129+
let base = args.pop().ok_or_else(|| {
130+
plan_datafusion_err!("Expected power to have 2 arguments, got 1")
131+
})?;
132+
133+
let exponent_type = info.get_data_type(&exponent)?;
127134
match exponent {
128-
Expr::Literal(value)
129-
if value == &ScalarValue::new_zero(&info.get_data_type(exponent)?)? =>
130-
{
135+
Expr::Literal(value) if value == ScalarValue::new_zero(&exponent_type)? => {
131136
Ok(ExprSimplifyResult::Simplified(Expr::Literal(
132-
ScalarValue::new_one(&info.get_data_type(base)?)?,
137+
ScalarValue::new_one(&info.get_data_type(&base)?)?,
133138
)))
134139
}
135-
Expr::Literal(value)
136-
if value == &ScalarValue::new_one(&info.get_data_type(exponent)?)? =>
137-
{
138-
Ok(ExprSimplifyResult::Simplified(base.clone()))
140+
Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => {
141+
Ok(ExprSimplifyResult::Simplified(base))
139142
}
140-
Expr::ScalarFunction(ScalarFunction {
141-
func_def: ScalarFunctionDefinition::UDF(fun),
142-
args,
143-
}) if base == &args[0]
144-
&& fun
145-
.as_ref()
146-
.inner()
147-
.as_any()
148-
.downcast_ref::<LogFunc>()
149-
.is_some() =>
143+
Expr::ScalarFunction(ScalarFunction { func_def, mut args })
144+
if is_log(&func_def) && args.len() == 2 && base == args[0] =>
150145
{
151-
Ok(ExprSimplifyResult::Simplified(args[1].clone()))
146+
let b = args.pop().unwrap(); // length checked above
147+
Ok(ExprSimplifyResult::Simplified(b))
152148
}
153-
_ => Ok(ExprSimplifyResult::Original(args)),
149+
_ => Ok(ExprSimplifyResult::Original(vec![base, exponent])),
154150
}
155151
}
156152
}
157153

154+
/// Return true if this function call is a call to `Log`
155+
fn is_log(func_def: &ScalarFunctionDefinition) -> bool {
156+
if let ScalarFunctionDefinition::UDF(fun) = func_def {
157+
fun.as_ref()
158+
.inner()
159+
.as_any()
160+
.downcast_ref::<LogFunc>()
161+
.is_some()
162+
} else {
163+
false
164+
}
165+
}
166+
158167
#[cfg(test)]
159168
mod tests {
160169
use datafusion_common::cast::{as_float64_array, as_int64_array};

0 commit comments

Comments
 (0)