Skip to content

Commit 032117a

Browse files
AdamGSJefffrey
andauthored
More decimal 32/64 support - type coercsion and misc gaps (#17808)
* More small decimal support * CR comments Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com> * Add tests and cleanup some code --------- Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com>
1 parent 7d6d553 commit 032117a

File tree

7 files changed

+353
-22
lines changed

7 files changed

+353
-22
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,6 +1362,12 @@ impl ScalarValue {
13621362
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))),
13631363
DataType::Float32 => ScalarValue::Float32(Some(0.0)),
13641364
DataType::Float64 => ScalarValue::Float64(Some(0.0)),
1365+
DataType::Decimal32(precision, scale) => {
1366+
ScalarValue::Decimal32(Some(0), *precision, *scale)
1367+
}
1368+
DataType::Decimal64(precision, scale) => {
1369+
ScalarValue::Decimal64(Some(0), *precision, *scale)
1370+
}
13651371
DataType::Decimal128(precision, scale) => {
13661372
ScalarValue::Decimal128(Some(0), *precision, *scale)
13671373
}

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,16 @@ impl<'a> BinaryTypeCoercer<'a> {
327327

328328
// TODO Move the rest inside of BinaryTypeCoercer
329329

330+
fn is_decimal(data_type: &DataType) -> bool {
331+
matches!(
332+
data_type,
333+
DataType::Decimal32(..)
334+
| DataType::Decimal64(..)
335+
| DataType::Decimal128(..)
336+
| DataType::Decimal256(..)
337+
)
338+
}
339+
330340
/// Coercion rules for mathematics operators between decimal and non-decimal types.
331341
fn math_decimal_coercion(
332342
lhs_type: &DataType,
@@ -357,6 +367,15 @@ fn math_decimal_coercion(
357367
| (Decimal256(_, _), Decimal256(_, _)) => {
358368
Some((lhs_type.clone(), rhs_type.clone()))
359369
}
370+
// Cross-variant decimal coercion - choose larger variant with appropriate precision/scale
371+
(lhs, rhs)
372+
if is_decimal(lhs)
373+
&& is_decimal(rhs)
374+
&& std::mem::discriminant(lhs) != std::mem::discriminant(rhs) =>
375+
{
376+
let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?;
377+
Some((coerced_type.clone(), coerced_type))
378+
}
360379
// Unlike with comparison we don't coerce to a decimal in the case of floating point
361380
// numbers, instead falling back to floating point arithmetic instead
362381
(
@@ -953,21 +972,92 @@ pub fn binary_numeric_coercion(
953972
pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
954973
use arrow::datatypes::DataType::*;
955974

975+
// Prefer decimal data type over floating point for comparison operation
956976
match (lhs_type, rhs_type) {
957-
// Prefer decimal data type over floating point for comparison operation
958-
(Decimal128(_, _), Decimal128(_, _)) => {
977+
// Same decimal types
978+
(lhs_type, rhs_type)
979+
if is_decimal(lhs_type)
980+
&& is_decimal(rhs_type)
981+
&& std::mem::discriminant(lhs_type)
982+
== std::mem::discriminant(rhs_type) =>
983+
{
959984
get_wider_decimal_type(lhs_type, rhs_type)
960985
}
961-
(Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
962-
(_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
963-
(Decimal256(_, _), Decimal256(_, _)) => {
964-
get_wider_decimal_type(lhs_type, rhs_type)
986+
// Mismatched decimal types
987+
(lhs_type, rhs_type)
988+
if is_decimal(lhs_type)
989+
&& is_decimal(rhs_type)
990+
&& std::mem::discriminant(lhs_type)
991+
!= std::mem::discriminant(rhs_type) =>
992+
{
993+
get_wider_decimal_type_cross_variant(lhs_type, rhs_type)
994+
}
995+
// Decimal + non-decimal types
996+
(Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), _) => {
997+
get_common_decimal_type(lhs_type, rhs_type)
998+
}
999+
(_, Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _)) => {
1000+
get_common_decimal_type(rhs_type, lhs_type)
9651001
}
966-
(Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
967-
(_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
9681002
(_, _) => None,
9691003
}
9701004
}
1005+
/// Handle cross-variant decimal widening by choosing the larger variant
1006+
fn get_wider_decimal_type_cross_variant(
1007+
lhs_type: &DataType,
1008+
rhs_type: &DataType,
1009+
) -> Option<DataType> {
1010+
use arrow::datatypes::DataType::*;
1011+
1012+
let (p1, s1) = match lhs_type {
1013+
Decimal32(p, s) => (*p, *s),
1014+
Decimal64(p, s) => (*p, *s),
1015+
Decimal128(p, s) => (*p, *s),
1016+
Decimal256(p, s) => (*p, *s),
1017+
_ => return None,
1018+
};
1019+
1020+
let (p2, s2) = match rhs_type {
1021+
Decimal32(p, s) => (*p, *s),
1022+
Decimal64(p, s) => (*p, *s),
1023+
Decimal128(p, s) => (*p, *s),
1024+
Decimal256(p, s) => (*p, *s),
1025+
_ => return None,
1026+
};
1027+
1028+
// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
1029+
let s = s1.max(s2);
1030+
let range = (p1 as i8 - s1).max(p2 as i8 - s2);
1031+
let required_precision = (range + s) as u8;
1032+
1033+
// Choose the larger variant between the two input types, while making sure we don't overflow the precision.
1034+
match (lhs_type, rhs_type) {
1035+
(Decimal32(_, _), Decimal64(_, _)) | (Decimal64(_, _), Decimal32(_, _))
1036+
if required_precision <= DECIMAL64_MAX_PRECISION =>
1037+
{
1038+
Some(Decimal64(required_precision, s))
1039+
}
1040+
(Decimal32(_, _), Decimal128(_, _))
1041+
| (Decimal128(_, _), Decimal32(_, _))
1042+
| (Decimal64(_, _), Decimal128(_, _))
1043+
| (Decimal128(_, _), Decimal64(_, _))
1044+
if required_precision <= DECIMAL128_MAX_PRECISION =>
1045+
{
1046+
Some(Decimal128(required_precision, s))
1047+
}
1048+
(Decimal32(_, _), Decimal256(_, _))
1049+
| (Decimal256(_, _), Decimal32(_, _))
1050+
| (Decimal64(_, _), Decimal256(_, _))
1051+
| (Decimal256(_, _), Decimal64(_, _))
1052+
| (Decimal128(_, _), Decimal256(_, _))
1053+
| (Decimal256(_, _), Decimal128(_, _))
1054+
if required_precision <= DECIMAL256_MAX_PRECISION =>
1055+
{
1056+
Some(Decimal256(required_precision, s))
1057+
}
1058+
_ => None,
1059+
}
1060+
}
9711061

9721062
/// Coerce `lhs_type` and `rhs_type` to a common type.
9731063
fn get_common_decimal_type(
@@ -976,7 +1066,15 @@ fn get_common_decimal_type(
9761066
) -> Option<DataType> {
9771067
use arrow::datatypes::DataType::*;
9781068
match decimal_type {
979-
Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) => {
1069+
Decimal32(_, _) => {
1070+
let other_decimal_type = coerce_numeric_type_to_decimal32(other_type)?;
1071+
get_wider_decimal_type(decimal_type, &other_decimal_type)
1072+
}
1073+
Decimal64(_, _) => {
1074+
let other_decimal_type = coerce_numeric_type_to_decimal64(other_type)?;
1075+
get_wider_decimal_type(decimal_type, &other_decimal_type)
1076+
}
1077+
Decimal128(_, _) => {
9801078
let other_decimal_type = coerce_numeric_type_to_decimal128(other_type)?;
9811079
get_wider_decimal_type(decimal_type, &other_decimal_type)
9821080
}
@@ -988,7 +1086,7 @@ fn get_common_decimal_type(
9881086
}
9891087
}
9901088

991-
/// Returns a `DataType::Decimal128` that can store any value from either
1089+
/// Returns a decimal [`DataType`] variant that can store any value from either
9921090
/// `lhs_decimal_type` and `rhs_decimal_type`
9931091
///
9941092
/// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))`.
@@ -1209,14 +1307,14 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTy
12091307
}
12101308

12111309
fn create_decimal32_type(precision: u8, scale: i8) -> DataType {
1212-
DataType::Decimal128(
1310+
DataType::Decimal32(
12131311
DECIMAL32_MAX_PRECISION.min(precision),
12141312
DECIMAL32_MAX_SCALE.min(scale),
12151313
)
12161314
}
12171315

12181316
fn create_decimal64_type(precision: u8, scale: i8) -> DataType {
1219-
DataType::Decimal128(
1317+
DataType::Decimal64(
12201318
DECIMAL64_MAX_PRECISION.min(precision),
12211319
DECIMAL64_MAX_SCALE.min(scale),
12221320
)

datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,133 @@ fn test_coercion_arithmetic_decimal() -> Result<()> {
291291

292292
Ok(())
293293
}
294+
295+
#[test]
296+
fn test_coercion_arithmetic_decimal_cross_variant() -> Result<()> {
297+
let test_cases = [
298+
(
299+
DataType::Decimal32(5, 2),
300+
DataType::Decimal64(10, 3),
301+
DataType::Decimal64(10, 3),
302+
DataType::Decimal64(10, 3),
303+
),
304+
(
305+
DataType::Decimal32(7, 1),
306+
DataType::Decimal128(15, 4),
307+
DataType::Decimal128(15, 4),
308+
DataType::Decimal128(15, 4),
309+
),
310+
(
311+
DataType::Decimal32(9, 0),
312+
DataType::Decimal256(20, 5),
313+
DataType::Decimal256(20, 5),
314+
DataType::Decimal256(20, 5),
315+
),
316+
(
317+
DataType::Decimal64(12, 3),
318+
DataType::Decimal128(18, 2),
319+
DataType::Decimal128(19, 3),
320+
DataType::Decimal128(19, 3),
321+
),
322+
(
323+
DataType::Decimal64(15, 4),
324+
DataType::Decimal256(25, 6),
325+
DataType::Decimal256(25, 6),
326+
DataType::Decimal256(25, 6),
327+
),
328+
(
329+
DataType::Decimal128(20, 5),
330+
DataType::Decimal256(30, 8),
331+
DataType::Decimal256(30, 8),
332+
DataType::Decimal256(30, 8),
333+
),
334+
// Reverse order cases
335+
(
336+
DataType::Decimal64(10, 3),
337+
DataType::Decimal32(5, 2),
338+
DataType::Decimal64(10, 3),
339+
DataType::Decimal64(10, 3),
340+
),
341+
(
342+
DataType::Decimal128(15, 4),
343+
DataType::Decimal32(7, 1),
344+
DataType::Decimal128(15, 4),
345+
DataType::Decimal128(15, 4),
346+
),
347+
(
348+
DataType::Decimal256(20, 5),
349+
DataType::Decimal32(9, 0),
350+
DataType::Decimal256(20, 5),
351+
DataType::Decimal256(20, 5),
352+
),
353+
(
354+
DataType::Decimal128(18, 2),
355+
DataType::Decimal64(12, 3),
356+
DataType::Decimal128(19, 3),
357+
DataType::Decimal128(19, 3),
358+
),
359+
(
360+
DataType::Decimal256(25, 6),
361+
DataType::Decimal64(15, 4),
362+
DataType::Decimal256(25, 6),
363+
DataType::Decimal256(25, 6),
364+
),
365+
(
366+
DataType::Decimal256(30, 8),
367+
DataType::Decimal128(20, 5),
368+
DataType::Decimal256(30, 8),
369+
DataType::Decimal256(30, 8),
370+
),
371+
];
372+
373+
for (lhs_type, rhs_type, expected_lhs_type, expected_rhs_type) in test_cases {
374+
test_math_decimal_coercion_rule(
375+
lhs_type,
376+
rhs_type,
377+
expected_lhs_type,
378+
expected_rhs_type,
379+
);
380+
}
381+
382+
Ok(())
383+
}
384+
385+
#[test]
386+
fn test_decimal_precision_overflow_cross_variant() -> Result<()> {
387+
// s = max(0, 1) = 1, range = max(76-0, 38-1) = 76, required_precision = 76 + 1 = 77 (overflow)
388+
let result = get_wider_decimal_type_cross_variant(
389+
&DataType::Decimal256(76, 0),
390+
&DataType::Decimal128(38, 1),
391+
);
392+
assert!(result.is_none());
393+
394+
// s = max(0, 10) = 10, range = max(9-0, 18-10) = 9, required_precision = 9 + 10 = 19 (overflow > 18)
395+
let result = get_wider_decimal_type_cross_variant(
396+
&DataType::Decimal32(9, 0),
397+
&DataType::Decimal64(18, 10),
398+
);
399+
assert!(result.is_none());
400+
401+
// s = max(5, 26) = 26, range = max(18-5, 38-26) = 13, required_precision = 13 + 26 = 39 (overflow > 38)
402+
let result = get_wider_decimal_type_cross_variant(
403+
&DataType::Decimal64(18, 5),
404+
&DataType::Decimal128(38, 26),
405+
);
406+
assert!(result.is_none());
407+
408+
// s = max(10, 49) = 49, range = max(38-10, 76-49) = 28, required_precision = 28 + 49 = 77 (overflow > 76)
409+
let result = get_wider_decimal_type_cross_variant(
410+
&DataType::Decimal128(38, 10),
411+
&DataType::Decimal256(76, 49),
412+
);
413+
assert!(result.is_none());
414+
415+
// s = max(2, 3) = 3, range = max(5-2, 10-3) = 7, required_precision = 7 + 3 = 10 (valid <= 18)
416+
let result = get_wider_decimal_type_cross_variant(
417+
&DataType::Decimal32(5, 2),
418+
&DataType::Decimal64(10, 3),
419+
);
420+
assert!(result.is_some());
421+
422+
Ok(())
423+
}

0 commit comments

Comments
 (0)