-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Refactor log() signature to use coercion API + fixes
#18519
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note to self: add the tests Done
| logical_plan | ||
| 01)Sort: log_c11_base_c12 ASC NULLS LAST | ||
| 02)--Projection: log(aggregate_test_100.c12, CAST(aggregate_test_100.c11 AS Float64)) AS log_c11_base_c12 | ||
| 02)--Projection: log(aggregate_test_100.c12, aggregate_test_100.c11) AS log_c11_base_c12 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
c12 is f64, c11 is f32; now log() will handle casting them to the right types internally for it's implementation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that is probably more efficient than trying to cast externally anyways
| # TODO: this should be 116.267483321058, error with native decimal log impl | ||
| query R | ||
| select log(2.0, 100000000000000000000000000000000000::decimal(38,0)); | ||
| ---- | ||
| 116.267483321058 | ||
| 116 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is interesting; we weren't actually making use of the native decimal implementation for log here apparently. The decimal value was actually being casted to float. Fixing the signature to have decimals take precedence (so they don't get casted to float) reveals a bug in our decimal log implementation.
Note to self: raise separate issue for this
| // Null propagation | ||
| if arg_types.iter().any(|dt| dt.is_null()) { | ||
| return Ok(ExprSimplifyResult::Simplified(lit( | ||
| ScalarValue::Null.cast_to(&return_type)? | ||
| ))); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So any log(null, x), log(x, null) or log(null) should be immediately simplified to null.
I'll add a test case to show this actually has an effect.
- Added by 0da79bd
|
|
||
| impl LogFunc { | ||
| pub fn new() -> Self { | ||
| // Converts decimals & integers to float64, accepting other floats as is |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Main fix here for signature; to me this is more readable than the previous signature, and importantly it accepts any decimals regardless of precision/scale.
Note to self: add tests for decimals of different precision/scale if they don't exist
- Added by 0da79bd
|
|
||
| // Support overloaded log(base, x) and log(x) which defaults to log(10, x) | ||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| let args = ColumnarValue::values_to_arrays(&args.args)?; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was converting all the args to array, losing optimization opportunity if the base was a scalar. Fix so we maintain the scalar throughout. the value will need to be an array because calculate_binary_math requires its left argument (value for us) to be an array anyway:
datafusion/datafusion/functions/src/utils.rs
Lines 131 to 135 in a5eb912
| pub fn calculate_binary_math<L, R, O, F>( | |
| left: &dyn Array, | |
| right: &ColumnarValue, | |
| fun: F, | |
| ) -> Result<Arc<PrimitiveArray<O>>> |
| Float64Type, | ||
| _, | ||
| >(value, &base, |x, b| Ok(f64::log(x, b)))?, | ||
| DataType::Int32 => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed arms for Int32 & Int64 since they should now be casted to Float64 by the new signature; this is functionally equivalent to what the implementations for Int32 & Int64 did anyway
| # Null propagation for log | ||
| query TT | ||
| EXPLAIN SELECT log(NULL, c2) from aggregate_simple; | ||
| ---- | ||
| logical_plan | ||
| 01)Projection: Float64(NULL) AS log(NULL,aggregate_simple.c2) | ||
| 02)--TableScan: aggregate_simple projection=[] | ||
| physical_plan | ||
| 01)ProjectionExec: expr=[NULL as log(NULL,aggregate_simple.c2)] | ||
| 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_simple.csv]]}, file_type=csv, has_header=true |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On main this fails like so:
1. query result mismatch:
[SQL] EXPLAIN SELECT log(NULL, c2) from aggregate_simple;
[Diff] (-expected|+actual)
logical_plan
- 01)Projection: Float64(NULL) AS log(NULL,aggregate_simple.c2)
- 02)--TableScan: aggregate_simple projection=[]
+ 01)Projection: log(Float64(NULL), aggregate_simple.c2) AS log(NULL,aggregate_simple.c2)
+ 02)--TableScan: aggregate_simple projection=[c2]
physical_plan
- 01)ProjectionExec: expr=[NULL as log(NULL,aggregate_simple.c2)]
- 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_simple.csv]]}, file_type=csv, has_header=true
+ 01)ProjectionExec: expr=[log(NULL, c2@0) as log(NULL,aggregate_simple.c2)]
+ 02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+ 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_simple.csv]]}, projection=[c2], file_type=csv, has_header=true
at /Users/jeffrey/Code/datafusion/datafusion/sqllogictest/test_files/math.slt:710- On main it still computes the log; in this PR the log function gets optimized away entirely
alamb
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| logical_plan | ||
| 01)Sort: log_c11_base_c12 ASC NULLS LAST | ||
| 02)--Projection: log(aggregate_test_100.c12, CAST(aggregate_test_100.c11 AS Float64)) AS log_c11_base_c12 | ||
| 02)--Projection: log(aggregate_test_100.c12, aggregate_test_100.c11) AS log_c11_base_c12 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that is probably more efficient than trying to cast externally anyways
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes apache#123` indicates that this PR will close issue apache#123. --> Part of apache#14763 and apache#14760 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Current `log()` signature has some drawbacks: https://github.com/apache/datafusion/blob/a5eb9121ccf802dda547897155403b08a4fbf774/datafusion/functions/src/math/log.rs#L78-L105 - A bit nasty to look at: mixes numeric with exact float/int with exact decimal (of exact precision and scale) - Can't accommodate arbitrary decimals of any precision/scale (this is true for other functions too) Aim of this PR is to refactor it to use the coercion API, uplifting the API where necessary to make this possible. This simplifies the signature in code, whilst not losing flexibility. Also other minor refactors are included to log. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> New `TypeSignatureClass` variants: Float, Decimal & Numeric Refactor `log()` signature to be more in line with it's supported implementations. Fix issue in `log()` where `ColumnarValue::Scalar`s were being lost as `ColumnarValue::Array`s for the base. Support null propagation in `simplify()` for `log()`. ~~Fix issue with `calculate_binary_math` where it wasn't casting scalars.~~ ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Added new tests. - Tests for float16, decimal32, decimal64, decimals with different scales/precisions - Test for null propagation (ensure use array input to avoid function inlining) ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> No. <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes apache#123` indicates that this PR will close issue apache#123. --> Part of apache#14763 and apache#14760 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Current `log()` signature has some drawbacks: https://github.com/apache/datafusion/blob/a5eb9121ccf802dda547897155403b08a4fbf774/datafusion/functions/src/math/log.rs#L78-L105 - A bit nasty to look at: mixes numeric with exact float/int with exact decimal (of exact precision and scale) - Can't accommodate arbitrary decimals of any precision/scale (this is true for other functions too) Aim of this PR is to refactor it to use the coercion API, uplifting the API where necessary to make this possible. This simplifies the signature in code, whilst not losing flexibility. Also other minor refactors are included to log. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> New `TypeSignatureClass` variants: Float, Decimal & Numeric Refactor `log()` signature to be more in line with it's supported implementations. Fix issue in `log()` where `ColumnarValue::Scalar`s were being lost as `ColumnarValue::Array`s for the base. Support null propagation in `simplify()` for `log()`. ~~Fix issue with `calculate_binary_math` where it wasn't casting scalars.~~ ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Added new tests. - Tests for float16, decimal32, decimal64, decimals with different scales/precisions - Test for null propagation (ensure use array input to avoid function inlining) ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> No. <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
Which issue does this PR close?
Part of #14763 and #14760
Rationale for this change
Current
log()signature has some drawbacks:datafusion/datafusion/functions/src/math/log.rs
Lines 78 to 105 in a5eb912
Aim of this PR is to refactor it to use the coercion API, uplifting the API where necessary to make this possible. This simplifies the signature in code, whilst not losing flexibility.
Also other minor refactors are included to log.
What changes are included in this PR?
New
TypeSignatureClassvariants: Float, Decimal & NumericRefactor
log()signature to be more in line with it's supported implementations.Fix issue in
log()whereColumnarValue::Scalars were being lost asColumnarValue::Arrays for the base.Support null propagation in
simplify()forlog().Fix issue withcalculate_binary_mathwhere it wasn't casting scalars.Are these changes tested?
Added new tests.
Are there any user-facing changes?
No.