Skip to content

Commit

Permalink
Fix Decimal and Floating type coerce rule
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Oct 31, 2022
1 parent f4d70ac commit af10480
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 11 deletions.
4 changes: 2 additions & 2 deletions benchmarks/expected-plans/q11.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Sort: value DESC NULLS FIRST
Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value
Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) > __sq_1.__value
Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15))
CrossJoin:
Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
Expand All @@ -9,7 +9,7 @@ Sort: value DESC NULLS FIRST
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name]
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q14.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Projection: CAST(Decimal128(Some(1000000000000000000000),38,19) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Decimal128(38, 19)) AS Decimal128(38, 38)) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Decimal128(38, 38)) AS promo_revenue
Projection: Float64(100) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue
Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal128(Some(100),23,2)CAST(lineitem.l_extendedprice AS Decimal128(38, 4))lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(100),23,2) - lineitem.l_discount ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal128(Some(100),23,2)CAST(lineitem.l_extendedprice AS Decimal128(38, 4))lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(100),23,2) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
Projection: CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal128(Some(100),23,2)CAST(lineitem.l_extendedprice AS Decimal128(38, 4))lineitem.l_extendedprice, part.p_type
Inner Join: lineitem.l_partkey = part.p_partkey
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/expected-plans/q20.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ Sort: supplier.s_name ASC NULLS LAST
Filter: nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name]
Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) > __sq_3.__value
Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value
Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey
LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]
Projection: part.p_partkey AS p_partkey, alias=__sq_1
Filter: part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name]
Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value, alias=__sq_3
Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]]
Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131")
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate]
34 changes: 34 additions & 0 deletions datafusion/core/tests/sql/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -879,3 +879,37 @@ async fn decimal_null_array_scalar_comparison() -> Result<()> {
assert_eq!(&DataType::Boolean, actual[0].column(0).data_type());
Ok(())
}

#[tokio::test]
async fn decimal_multiply_float() -> Result<()> {
let ctx = SessionContext::new();
let sql = "select cast(400420638.54 as decimal(12,2));";
let actual = execute_to_batches(&ctx, sql).await;

assert_eq!(
&DataType::Decimal128(12, 2),
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+-----------------------+",
"| Float64(400420638.54) |",
"+-----------------------+",
"| 400420638.54 |",
"+-----------------------+",
];
assert_batches_eq!(expected, &actual);

let sql = "select cast(400420638.54 as decimal(12,2)) * 1.0;";
let actual = execute_to_batches(&ctx, sql).await;
assert_eq!(&DataType::Float64, actual[0].schema().field(0).data_type());
let expected = vec![
"+------------------------------------+",
"| Float64(400420638.54) * Float64(1) |",
"+------------------------------------+",
"| 400420638.54 |",
"+------------------------------------+",
];
assert_batches_eq!(expected, &actual);

Ok(())
}
8 changes: 4 additions & 4 deletions datafusion/core/tests/sql/subqueries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,14 @@ order by s_name;
Filter: nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")]
Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) > __sq_3.__value
Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value
Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey
LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]
Projection: part.p_partkey AS p_partkey, alias=__sq_1
Filter: part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")]
Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value, alias=__sq_3
Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]]
Filter: lineitem.l_shipdate >= Date32("8766")
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("8766")]"#
Expand Down Expand Up @@ -443,7 +443,7 @@ order by value desc;
let actual = format!("{}", plan.display_indent());
let expected = r#"Sort: value DESC NULLS FIRST
Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value
Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) > __sq_1.__value
Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15))
CrossJoin:
Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
Expand All @@ -452,7 +452,7 @@ order by value desc;
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")]
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1488,6 +1488,7 @@ impl Subquery {
pub fn try_from_expr(plan: &Expr) -> datafusion_common::Result<&Subquery> {
match plan {
Expr::ScalarSubquery(it) => Ok(it),
Expr::Cast(cast) => Subquery::try_from_expr(cast.expr.as_ref()),
_ => plan_err!("Could not coerce into ScalarSubquery!"),
}
}
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ fn mathematics_numerical_coercion(
(Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => {
Some(dec_type.clone())
}
(Decimal128(_, _), Float32 | Float64) => Some(Float64),
(Float32 | Float64, Decimal128(_, _)) => Some(Float64),
(Decimal128(_, _), _) => {
let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type);
match converted_decimal_type {
Expand Down
136 changes: 134 additions & 2 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2574,9 +2574,22 @@ mod tests {
let right_expr = if right.data_type().eq(&op_type) {
col("b", schema)?
} else {
try_cast(col("b", schema)?, schema, op_type)?
try_cast(col("b", schema)?, schema, op_type.clone())?
};
let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);

let coerced_schema = Schema::new(vec![
Field::new(
schema.field(0).name(),
op_type.clone(),
schema.field(0).is_nullable(),
),
Field::new(
schema.field(1).name(),
op_type,
schema.field(1).is_nullable(),
),
]);
let arithmetic_op = binary_simple(left_expr, op, right_expr, &coerced_schema);
let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
let batch = RecordBatch::try_new(schema.clone(), data)?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
Expand Down Expand Up @@ -2704,6 +2717,125 @@ mod tests {
Ok(())
}

#[test]
fn arithmetic_decimal_float_expr_test() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float64, true),
Field::new("b", DataType::Decimal128(10, 2), true),
]));
let value: i128 = 123;
let decimal_array = Arc::new(create_decimal_array(
&[
Some(value as i128), // 1.23
None,
Some((value - 1) as i128), // 1.22
Some((value + 1) as i128), // 1.24
],
10,
2,
)) as ArrayRef;
let float64_array = Arc::new(Float64Array::from(vec![
Some(123.0),
Some(122.0),
Some(123.0),
Some(124.0),
])) as ArrayRef;

// add: float64 array add decimal array
let expect = Arc::new(Float64Array::from(vec![
Some(124.23),
None,
Some(124.22),
Some(125.24),
])) as ArrayRef;
apply_arithmetic_op(
&schema,
&float64_array,
&decimal_array,
Operator::Plus,
expect,
)
.unwrap();

// subtract: decimal array subtract float64 array
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float64, true),
Field::new("b", DataType::Decimal128(10, 2), true),
]));
let expect = Arc::new(Float64Array::from(vec![
Some(121.77),
None,
Some(121.78),
Some(122.76),
])) as ArrayRef;
apply_arithmetic_op(
&schema,
&float64_array,
&decimal_array,
Operator::Minus,
expect,
)
.unwrap();

// multiply: decimal array multiply float64 array
let expect = Arc::new(Float64Array::from(vec![
Some(151.29),
None,
Some(150.06),
Some(153.76),
])) as ArrayRef;
apply_arithmetic_op(
&schema,
&float64_array,
&decimal_array,
Operator::Multiply,
expect,
)
.unwrap();

// divide: float64 array divide decimal array
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float64, true),
Field::new("b", DataType::Decimal128(10, 2), true),
]));
let expect = Arc::new(Float64Array::from(vec![
Some(100.0),
None,
Some(100.81967213114754),
Some(100.0),
])) as ArrayRef;
apply_arithmetic_op(
&schema,
&float64_array,
&decimal_array,
Operator::Divide,
expect,
)
.unwrap();

// modulus: float64 array modulus decimal array
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float64, true),
Field::new("b", DataType::Decimal128(10, 2), true),
]));
let expect = Arc::new(Float64Array::from(vec![
Some(1.7763568394002505e-15),
None,
Some(1.0000000000000027),
Some(8.881784197001252e-16),
])) as ArrayRef;
apply_arithmetic_op(
&schema,
&float64_array,
&decimal_array,
Operator::Modulo,
expect,
)
.unwrap();

Ok(())
}

#[test]
fn bitwise_array_test() -> Result<()> {
let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
Expand Down

0 comments on commit af10480

Please sign in to comment.