Skip to content

Commit 0da79bd

Browse files
committed
Tests & fix casting dec32/64 to dec128
1 parent 2ce83f0 commit 0da79bd

File tree

4 files changed

+84
-7
lines changed

4 files changed

+84
-7
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,7 +1734,7 @@ impl ScalarValue {
17341734
) {
17351735
return _internal_err!("Invalid precision and scale {err}");
17361736
}
1737-
if *scale <= 0 {
1737+
if *scale < 0 {
17381738
return _internal_err!("Negative scale is not supported");
17391739
}
17401740
match 10_i32.checked_pow((*scale + 1) as u32) {
@@ -1750,7 +1750,7 @@ impl ScalarValue {
17501750
) {
17511751
return _internal_err!("Invalid precision and scale {err}");
17521752
}
1753-
if *scale <= 0 {
1753+
if *scale < 0 {
17541754
return _internal_err!("Negative scale is not supported");
17551755
}
17561756
match i64::from(10).checked_pow((*scale + 1) as u32) {

datafusion/functions/src/math/log.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use super::power::PowerFunc;
2323

2424
use crate::utils::{calculate_binary_math, decimal128_to_i128};
2525
use arrow::array::{Array, ArrayRef};
26+
use arrow::compute::kernels::cast;
2627
use arrow::datatypes::{
2728
DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type,
2829
};
@@ -224,9 +225,15 @@ impl ScalarUDFImpl for LogFunc {
224225
}
225226
// TODO: native log support for decimal 32 & 64; right now upcast
226227
// to decimal128 to calculate
227-
DataType::Decimal32(_, scale)
228-
| DataType::Decimal64(_, scale)
229-
| DataType::Decimal128(_, scale) => {
228+
DataType::Decimal32(precision, scale)
229+
| DataType::Decimal64(precision, scale) => {
230+
calculate_binary_math::<Decimal128Type, Float64Type, Float64Type, _>(
231+
&cast(&value, &DataType::Decimal128(*precision, *scale))?,
232+
&base,
233+
|value, base| log_decimal128(value, *scale, base),
234+
)?
235+
}
236+
DataType::Decimal128(_, scale) => {
230237
calculate_binary_math::<Decimal128Type, Float64Type, Float64Type, _>(
231238
&value,
232239
&base,
@@ -349,6 +356,17 @@ mod tests {
349356
use datafusion_expr::execution_props::ExecutionProps;
350357
use datafusion_expr::simplify::SimplifyContext;
351358

359+
#[test]
360+
fn test_log_decimal_native() {
361+
let value = 10_i128.pow(35);
362+
assert_eq!((value as f64).log2(), 116.26748332105768);
363+
assert_eq!(
364+
log_decimal128(value, 0, 2.0).unwrap(),
365+
// TODO: see we're losing our decimal points compared to above
366+
116.0
367+
);
368+
}
369+
352370
#[test]
353371
fn test_log_invalid_base_type() {
354372
let arg_fields = vec![

datafusion/sqllogictest/test_files/decimal.slt

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -794,15 +794,47 @@ select 100000000000000000000000000000000000::decimal(38,0), arrow_typeof(1000000
794794
----
795795
100000000000000000000000000000000000 Decimal128(38, 0)
796796

797+
# log for small decimal32
798+
query R
799+
select log(arrow_cast(100, 'Decimal32(9, 0)'));
800+
----
801+
2
802+
803+
query R
804+
select log(arrow_cast(100, 'Decimal32(9, 2)'));
805+
----
806+
2
807+
808+
# log for small decimal64
809+
query R
810+
select log(arrow_cast(100, 'Decimal64(18, 0)'));
811+
----
812+
2
813+
814+
query R
815+
select log(arrow_cast(100, 'Decimal64(18, 2)'));
816+
----
817+
2
818+
797819
# log for small decimal128
798820
query R
799-
select log(100::decimal(38,0));
821+
select log(arrow_cast(100, 'Decimal128(38, 0)'));
822+
----
823+
2
824+
825+
query R
826+
select log(arrow_cast(100, 'Decimal128(38, 2)'));
800827
----
801828
2
802829

803830
# log for small decimal256
804831
query R
805-
select log(100::decimal(76,0));
832+
select log(arrow_cast(100, 'Decimal256(76, 0)'));
833+
----
834+
2
835+
836+
query R
837+
select log(arrow_cast(100, 'Decimal256(76, 2)'));
806838
----
807839
2
808840

datafusion/sqllogictest/test_files/math.slt

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,3 +705,30 @@ select FACTORIAL(350943270);
705705

706706
statement ok
707707
drop table signed_integers
708+
709+
# Null propagation for log
710+
query TT
711+
EXPLAIN SELECT log(NULL, c2) from aggregate_simple;
712+
----
713+
logical_plan
714+
01)Projection: Float64(NULL) AS log(NULL,aggregate_simple.c2)
715+
02)--TableScan: aggregate_simple projection=[]
716+
physical_plan
717+
01)ProjectionExec: expr=[NULL as log(NULL,aggregate_simple.c2)]
718+
02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_simple.csv]]}, file_type=csv, has_header=true
719+
720+
# Float 16/32/64 for log
721+
query RT
722+
SELECT log(2.5, arrow_cast(10.9, 'Float16')), arrow_typeof(log(2.5, arrow_cast(10.9, 'Float16')));
723+
----
724+
2.6074219 Float16
725+
726+
query RT
727+
SELECT log(2.5, 10.9::float), arrow_typeof(log(2.5, 10.9::float));
728+
----
729+
2.606992 Float32
730+
731+
query RT
732+
SELECT log(2.5, 10.9::double), arrow_typeof(log(2.5, 10.9::double));
733+
----
734+
2.606992198152 Float64

0 commit comments

Comments
 (0)