|
18 | 18 | //! Math function: `log()`. |
19 | 19 |
|
20 | 20 | use arrow::datatypes::DataType; |
21 | | -use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; |
| 21 | +use datafusion_common::{ |
| 22 | + exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, |
| 23 | + ScalarValue, |
| 24 | +}; |
22 | 25 | use datafusion_expr::expr::ScalarFunction; |
23 | 26 | use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; |
24 | | -use datafusion_expr::{ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition}; |
| 27 | +use datafusion_expr::{ |
| 28 | + lit, ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition, |
| 29 | +}; |
25 | 30 |
|
26 | 31 | use arrow::array::{ArrayRef, Float32Array, Float64Array}; |
27 | 32 | use datafusion_expr::TypeSignature::*; |
@@ -146,51 +151,74 @@ impl ScalarUDFImpl for LogFunc { |
146 | 151 | /// 3. Log(a, a) ===> 1 |
147 | 152 | fn simplify( |
148 | 153 | &self, |
149 | | - args: Vec<Expr>, |
| 154 | + mut args: Vec<Expr>, |
150 | 155 | info: &dyn SimplifyInfo, |
151 | 156 | ) -> Result<ExprSimplifyResult> { |
152 | | - let mut number = &args[0]; |
153 | | - let mut base = |
154 | | - &Expr::Literal(ScalarValue::new_ten(&info.get_data_type(number)?)?); |
155 | | - if args.len() == 2 { |
156 | | - base = &args[0]; |
157 | | - number = &args[1]; |
| 157 | + // Args are either |
| 158 | + // log(number) |
| 159 | + // log(base, number) |
| 160 | + let num_args = args.len(); |
| 161 | + if num_args > 2 { |
| 162 | + return plan_err!("Expected log to have 1 or 2 arguments, got {num_args}"); |
158 | 163 | } |
| 164 | + let number = args.pop().ok_or_else(|| { |
| 165 | + plan_datafusion_err!("Expected log to have 1 or 2 arguments, got 0") |
| 166 | + })?; |
| 167 | + let number_datatype = info.get_data_type(&number)?; |
| 168 | + // default to base 10 |
| 169 | + let base = if let Some(base) = args.pop() { |
| 170 | + base |
| 171 | + } else { |
| 172 | + lit(ScalarValue::new_ten(&number_datatype)?) |
| 173 | + }; |
159 | 174 |
|
160 | 175 | match number { |
161 | | - Expr::Literal(value) |
162 | | - if value == &ScalarValue::new_one(&info.get_data_type(number)?)? => |
163 | | - { |
164 | | - Ok(ExprSimplifyResult::Simplified(Expr::Literal( |
165 | | - ScalarValue::new_zero(&info.get_data_type(base)?)?, |
166 | | - ))) |
| 176 | + Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype)? => { |
| 177 | + Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero( |
| 178 | + &info.get_data_type(&base)?, |
| 179 | + )?))) |
167 | 180 | } |
168 | | - Expr::ScalarFunction(ScalarFunction { |
169 | | - func_def: ScalarFunctionDefinition::UDF(fun), |
170 | | - args, |
171 | | - }) if base == &args[0] |
172 | | - && fun |
173 | | - .as_ref() |
174 | | - .inner() |
175 | | - .as_any() |
176 | | - .downcast_ref::<PowerFunc>() |
177 | | - .is_some() => |
| 181 | + Expr::ScalarFunction(ScalarFunction { func_def, mut args }) |
| 182 | + if is_pow(&func_def) && args.len() == 2 && base == args[0] => |
178 | 183 | { |
179 | | - Ok(ExprSimplifyResult::Simplified(args[1].clone())) |
| 184 | + let b = args.pop().unwrap(); // length checked above |
| 185 | + Ok(ExprSimplifyResult::Simplified(b)) |
180 | 186 | } |
181 | | - _ => { |
| 187 | + number => { |
182 | 188 | if number == base { |
183 | | - Ok(ExprSimplifyResult::Simplified(Expr::Literal( |
184 | | - ScalarValue::new_one(&info.get_data_type(number)?)?, |
185 | | - ))) |
| 189 | + Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one( |
| 190 | + &number_datatype, |
| 191 | + )?))) |
186 | 192 | } else { |
| 193 | + let args = match num_args { |
| 194 | + 1 => vec![number], |
| 195 | + 2 => vec![number, base], |
| 196 | + _ => { |
| 197 | + return internal_err!( |
| 198 | + "Unexpected number of arguments in log::simplify" |
| 199 | + ) |
| 200 | + } |
| 201 | + }; |
187 | 202 | Ok(ExprSimplifyResult::Original(args)) |
188 | 203 | } |
189 | 204 | } |
190 | 205 | } |
191 | 206 | } |
192 | 207 | } |
193 | 208 |
|
| 209 | +/// Returns true if the function is `PowerFunc` |
| 210 | +fn is_pow(func_def: &ScalarFunctionDefinition) -> bool { |
| 211 | + if let ScalarFunctionDefinition::UDF(fun) = func_def { |
| 212 | + fun.as_ref() |
| 213 | + .inner() |
| 214 | + .as_any() |
| 215 | + .downcast_ref::<PowerFunc>() |
| 216 | + .is_some() |
| 217 | + } else { |
| 218 | + false |
| 219 | + } |
| 220 | +} |
| 221 | + |
194 | 222 | #[cfg(test)] |
195 | 223 | mod tests { |
196 | 224 | use datafusion_common::cast::{as_float32_array, as_float64_array}; |
|
0 commit comments