Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Decimal and Floating type coerce rule #4038

Merged
merged 1 commit into from
Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/expected-plans/q11.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Sort: value DESC NULLS FIRST
Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value
Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) > __sq_1.__value
Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15))
CrossJoin:
Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
Expand All @@ -9,7 +9,7 @@ Sort: value DESC NULLS FIRST
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name]
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q14.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Projection: CAST(Decimal128(Some(1000000000000000000000),38,19) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Decimal128(38, 19)) AS Decimal128(38, 38)) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Decimal128(38, 38)) AS promo_revenue
Projection: Float64(100) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue
Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal128(Some(100),23,2)CAST(lineitem.l_extendedprice AS Decimal128(38, 4))lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(100),23,2) - lineitem.l_discount ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal128(Some(100),23,2)CAST(lineitem.l_extendedprice AS Decimal128(38, 4))lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(100),23,2) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
Projection: CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal128(Some(100),23,2)CAST(lineitem.l_extendedprice AS Decimal128(38, 4))lineitem.l_extendedprice, part.p_type
Inner Join: lineitem.l_partkey = part.p_partkey
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/expected-plans/q20.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ Sort: supplier.s_name ASC NULLS LAST
Filter: nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name]
Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) > __sq_3.__value
Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value
Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey
LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]
Projection: part.p_partkey AS p_partkey, alias=__sq_1
Filter: part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name]
Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value, alias=__sq_3
Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]]
Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131")
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate]
34 changes: 34 additions & 0 deletions datafusion/core/tests/sql/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -879,3 +879,37 @@ async fn decimal_null_array_scalar_comparison() -> Result<()> {
assert_eq!(&DataType::Boolean, actual[0].column(0).data_type());
Ok(())
}

#[tokio::test]
async fn decimal_multiply_float() -> Result<()> {
let ctx = SessionContext::new();
let sql = "select cast(400420638.54 as decimal(12,2));";
let actual = execute_to_batches(&ctx, sql).await;

assert_eq!(
&DataType::Decimal128(12, 2),
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+-----------------------+",
"| Float64(400420638.54) |",
"+-----------------------+",
"| 400420638.54 |",
"+-----------------------+",
];
assert_batches_eq!(expected, &actual);

let sql = "select cast(400420638.54 as decimal(12,2)) * 1.0;";
let actual = execute_to_batches(&ctx, sql).await;
assert_eq!(&DataType::Float64, actual[0].schema().field(0).data_type());
let expected = vec![
"+------------------------------------+",
"| Float64(400420638.54) * Float64(1) |",
"+------------------------------------+",
"| 400420638.54 |",
"+------------------------------------+",
];
assert_batches_eq!(expected, &actual);

Ok(())
}
8 changes: 4 additions & 4 deletions datafusion/core/tests/sql/subqueries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,14 @@ order by s_name;
Filter: nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")]
Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) > __sq_3.__value
Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value
Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey
LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]
Projection: part.p_partkey AS p_partkey, alias=__sq_1
Filter: part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")]
Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value, alias=__sq_3
Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]]
Filter: lineitem.l_shipdate >= Date32("8766")
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("8766")]"#
Expand Down Expand Up @@ -443,7 +443,7 @@ order by value desc;
let actual = format!("{}", plan.display_indent());
let expected = r#"Sort: value DESC NULLS FIRST
Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value
Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) > __sq_1.__value
Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15))
CrossJoin:
Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
Expand All @@ -452,7 +452,7 @@ order by value desc;
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")]
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1488,6 +1488,7 @@ impl Subquery {
pub fn try_from_expr(plan: &Expr) -> datafusion_common::Result<&Subquery> {
match plan {
Expr::ScalarSubquery(it) => Ok(it),
Expr::Cast(cast) => Subquery::try_from_expr(cast.expr.as_ref()),
_ => plan_err!("Could not coerce into ScalarSubquery!"),
}
}
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ fn mathematics_numerical_coercion(
(Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => {
Some(dec_type.clone())
}
(Decimal128(_, _), Float32 | Float64) => Some(Float64),
(Float32 | Float64, Decimal128(_, _)) => Some(Float64),
(Decimal128(_, _), _) => {
let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type);
match converted_decimal_type {
Expand Down
136 changes: 134 additions & 2 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2574,9 +2574,22 @@ mod tests {
let right_expr = if right.data_type().eq(&op_type) {
col("b", schema)?
} else {
try_cast(col("b", schema)?, schema, op_type)?
try_cast(col("b", schema)?, schema, op_type.clone())?
};
let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);

let coerced_schema = Schema::new(vec![
Field::new(
schema.field(0).name(),
op_type.clone(),
schema.field(0).is_nullable(),
),
Field::new(
schema.field(1).name(),
op_type,
schema.field(1).is_nullable(),
),
]);
let arithmetic_op = binary_simple(left_expr, op, right_expr, &coerced_schema);
let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
let batch = RecordBatch::try_new(schema.clone(), data)?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
Expand Down Expand Up @@ -2704,6 +2717,125 @@ mod tests {
Ok(())
}

#[test]
fn arithmetic_decimal_float_expr_test() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float64, true),
Field::new("b", DataType::Decimal128(10, 2), true),
]));
let value: i128 = 123;
let decimal_array = Arc::new(create_decimal_array(
&[
Some(value as i128), // 1.23
None,
Some((value - 1) as i128), // 1.22
Some((value + 1) as i128), // 1.24
],
10,
2,
)) as ArrayRef;
let float64_array = Arc::new(Float64Array::from(vec![
Some(123.0),
Some(122.0),
Some(123.0),
Some(124.0),
])) as ArrayRef;

// add: float64 array add decimal array
let expect = Arc::new(Float64Array::from(vec![
Some(124.23),
None,
Some(124.22),
Some(125.24),
])) as ArrayRef;
apply_arithmetic_op(
&schema,
&float64_array,
&decimal_array,
Operator::Plus,
expect,
)
.unwrap();

// subtract: decimal array subtract float64 array
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float64, true),
Field::new("b", DataType::Decimal128(10, 2), true),
]));
let expect = Arc::new(Float64Array::from(vec![
Some(121.77),
None,
Some(121.78),
Some(122.76),
])) as ArrayRef;
apply_arithmetic_op(
&schema,
&float64_array,
&decimal_array,
Operator::Minus,
expect,
)
.unwrap();

// multiply: decimal array multiply float64 array
let expect = Arc::new(Float64Array::from(vec![
Some(151.29),
None,
Some(150.06),
Some(153.76),
])) as ArrayRef;
apply_arithmetic_op(
&schema,
&float64_array,
&decimal_array,
Operator::Multiply,
expect,
)
.unwrap();

// divide: float64 array divide decimal array
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float64, true),
Field::new("b", DataType::Decimal128(10, 2), true),
]));
let expect = Arc::new(Float64Array::from(vec![
Some(100.0),
None,
Some(100.81967213114754),
Some(100.0),
])) as ArrayRef;
apply_arithmetic_op(
&schema,
&float64_array,
&decimal_array,
Operator::Divide,
expect,
)
.unwrap();

// modulus: float64 array modulus decimal array
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float64, true),
Field::new("b", DataType::Decimal128(10, 2), true),
]));
let expect = Arc::new(Float64Array::from(vec![
Some(1.7763568394002505e-15),
None,
Some(1.0000000000000027),
Some(8.881784197001252e-16),
])) as ArrayRef;
apply_arithmetic_op(
&schema,
&float64_array,
&decimal_array,
Operator::Modulo,
expect,
)
.unwrap();

Ok(())
}

#[test]
fn bitwise_array_test() -> Result<()> {
let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
Expand Down