Skip to content

Commit f3bb84f

Browse files
authored
inlist: move type coercion to logical phase (#3472)
1 parent 0388682 commit f3bb84f

File tree

4 files changed

+171
-69
lines changed

4 files changed

+171
-69
lines changed

datafusion/core/src/physical_plan/planner.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,9 +1696,11 @@ mod tests {
16961696
async fn plan(logical_plan: &LogicalPlan) -> Result<Arc<dyn ExecutionPlan>> {
16971697
let mut session_state = make_session_state();
16981698
session_state.config.target_partitions = 4;
1699+
// optimize the logical plan
1700+
let logical_plan = session_state.optimize(logical_plan)?;
16991701
let planner = DefaultPhysicalPlanner::default();
17001702
planner
1701-
.create_physical_plan(logical_plan, &session_state)
1703+
.create_physical_plan(&logical_plan, &session_state)
17021704
.await
17031705
}
17041706

@@ -1714,12 +1716,12 @@ mod tests {
17141716
.limit(3, Some(10))?
17151717
.build()?;
17161718

1717-
let plan = plan(&logical_plan).await?;
1719+
let exec_plan = plan(&logical_plan).await?;
17181720

17191721
// verify that the plan correctly casts u8 to i64
17201722
// the cast here is implicit so has CastOptions with safe=true
1721-
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 6 }, op: Lt, right: TryCastExpr { expr: Literal { value: UInt8(5) }, cast_type: Int64 } }";
1722-
assert!(format!("{:?}", plan).contains(expected));
1723+
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) } }";
1724+
assert!(format!("{:?}", exec_plan).contains(expected));
17231725

17241726
Ok(())
17251727
}
@@ -1821,8 +1823,7 @@ mod tests {
18211823
async fn test_with_zero_offset_plan() -> Result<()> {
18221824
let logical_plan = test_csv_scan().await?.limit(0, None)?.build()?;
18231825
let plan = plan(&logical_plan).await?;
1824-
assert!(format!("{:?}", plan).contains("GlobalLimitExec"));
1825-
assert!(format!("{:?}", plan).contains("skip: 0"));
1826+
assert!(format!("{:?}", plan).contains("limit: None"));
18261827
Ok(())
18271828
}
18281829

@@ -1952,8 +1953,8 @@ mod tests {
19521953
.project(vec![col("c1").in_list(list, false)])?
19531954
.build()?;
19541955
let execution_plan = plan(&logical_plan).await?;
1955-
// verify that the plan correctly adds cast from Int64(1) to Utf8
1956-
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, TryCastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8 }], negated: false, set: None }";
1956+
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.
1957+
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"1\") }], negated: false, set: None }";
19571958
assert!(format!("{:?}", execution_plan).contains(expected));
19581959

19591960
// expression: "a in (struct::null, 'a')"
@@ -1965,10 +1966,9 @@ mod tests {
19651966
.filter(col("c12").lt(lit(0.05)))?
19661967
.project(vec![col("c12").lt_eq(lit(0.025)).in_list(list, false)])?
19671968
.build()?;
1968-
let execution_plan = plan(&logical_plan).await;
1969+
let e = plan(&logical_plan).await.unwrap_err().to_string();
19691970

1970-
let e = execution_plan.unwrap_err().to_string();
1971-
assert_contains!(&e, "Can not find compatible types to compare Boolean with [Struct([Field { name: \"foo\", data_type: Boolean, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }]), Utf8]");
1971+
assert_contains!(&e, "The data type inlist should be same, the value type is Boolean, one of list expr type is Struct([Field { name: \"foo\", data_type: Boolean, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }])");
19721972

19731973
Ok(())
19741974
}
@@ -1996,7 +1996,10 @@ mod tests {
19961996
.project(vec![col("c1").in_list(list, false)])?
19971997
.build()?;
19981998
let execution_plan = plan(&logical_plan).await?;
1999-
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, TryCastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(4) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(6) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(7) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(8) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(9) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(10) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(11) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(12) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(13) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(14) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(15) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(16) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(17) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(18) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(19) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(20) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(21) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(22) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(23) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(24) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(25) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(26) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(27) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(28) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(29) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(30) }, cast_type: Utf8 }], negated: false, set: Some(InSet { set: ";
1999+
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"1\") }, Literal { value: Utf8(\"2\") },";
2000+
assert!(format!("{:?}", execution_plan).contains(expected));
2001+
let expected =
2002+
"Literal { value: Utf8(\"30\") }], negated: false, set: Some(InSet { set: ";
20002003
assert!(format!("{:?}", execution_plan).contains(expected));
20012004
Ok(())
20022005
}
@@ -2015,7 +2018,10 @@ mod tests {
20152018
.project(vec![col("c1").in_list(list, false)])?
20162019
.build()?;
20172020
let execution_plan = plan(&logical_plan).await?;
2018-
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [TryCastExpr { expr: Literal { value: Int64(NULL) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(4) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(6) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(7) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(8) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(9) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(10) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(11) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(12) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(13) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(14) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(15) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(16) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(17) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(18) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(19) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(20) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(21) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(22) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(23) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(24) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(25) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(26) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(27) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(28) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(29) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(30) }, cast_type: Utf8 }], negated: false, set: Some(InSet {";
2021+
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(NULL) }, Literal { value: Utf8(\"1\") }, Literal { value: Utf8(\"2\") }";
2022+
assert!(format!("{:?}", execution_plan).contains(expected));
2023+
let expected =
2024+
"Literal { value: Utf8(\"30\") }], negated: false, set: Some(InSet";
20192025
assert!(format!("{:?}", execution_plan).contains(expected));
20202026
Ok(())
20212027
}

datafusion/optimizer/src/type_coercion.rs

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ impl ExprRewriter for TypeCoercionRewriter<'_> {
124124
right.clone().cast_to(&coerced_type, &self.schema)?,
125125
),
126126
};
127-
128127
expr.rewrite(&mut self.const_evaluator)
129128
}
130129
}
@@ -164,11 +163,61 @@ impl ExprRewriter for TypeCoercionRewriter<'_> {
164163
};
165164
expr.rewrite(&mut self.const_evaluator)
166165
}
166+
Expr::InList {
167+
expr,
168+
list,
169+
negated,
170+
} => {
171+
let expr_data_type = expr.get_type(&self.schema)?;
172+
let list_data_types = list
173+
.iter()
174+
.map(|list_expr| list_expr.get_type(&self.schema))
175+
.collect::<Result<Vec<_>>>()?;
176+
let result_type =
177+
get_coerce_type_for_list(&expr_data_type, &list_data_types);
178+
match result_type {
179+
None => Err(DataFusionError::Plan(format!(
180+
"Can not find compatible types to compare {:?} with {:?}",
181+
expr_data_type, list_data_types
182+
))),
183+
Some(coerced_type) => {
184+
// find the coerced type
185+
let cast_expr = expr.cast_to(&coerced_type, &self.schema)?;
186+
let cast_list_expr = list
187+
.into_iter()
188+
.map(|list_expr| {
189+
list_expr.cast_to(&coerced_type, &self.schema)
190+
})
191+
.collect::<Result<Vec<_>>>()?;
192+
let expr = Expr::InList {
193+
expr: Box::new(cast_expr),
194+
list: cast_list_expr,
195+
negated,
196+
};
197+
expr.rewrite(&mut self.const_evaluator)
198+
}
199+
}
200+
}
167201
expr => Ok(expr),
168202
}
169203
}
170204
}
171205

206+
/// Attempts to coerce the types of `list_types` to be comparable with the
207+
/// `expr_type`.
208+
/// Returns the common data type for `expr_type` and `list_types`
209+
fn get_coerce_type_for_list(
210+
expr_type: &DataType,
211+
list_types: &[DataType],
212+
) -> Option<DataType> {
213+
list_types
214+
.iter()
215+
.fold(Some(expr_type.clone()), |left, right_type| match left {
216+
None => None,
217+
Some(left_type) => comparison_coercion(&left_type, right_type),
218+
})
219+
}
220+
172221
/// Returns `expressions` coerced to types compatible with
173222
/// `signature`, if possible.
174223
///
@@ -348,6 +397,49 @@ mod test {
348397
Ok(())
349398
}
350399

400+
#[test]
401+
fn inlist_case() -> Result<()> {
402+
// a in (1,4,8), a is int64
403+
let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
404+
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
405+
produce_one_row: false,
406+
schema: Arc::new(
407+
DFSchema::new_with_metadata(
408+
vec![DFField::new(None, "a", DataType::Int64, true)],
409+
std::collections::HashMap::new(),
410+
)
411+
.unwrap(),
412+
),
413+
}));
414+
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
415+
let rule = TypeCoercion::new();
416+
let mut config = OptimizerConfig::default();
417+
let plan = rule.optimize(&plan, &mut config)?;
418+
assert_eq!(
419+
"Projection: #a IN ([Int64(1), Int64(4), Int64(8)])\n EmptyRelation",
420+
&format!("{:?}", plan)
421+
);
422+
// a in (1,4,8), a is decimal
423+
let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
424+
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
425+
produce_one_row: false,
426+
schema: Arc::new(
427+
DFSchema::new_with_metadata(
428+
vec![DFField::new(None, "a", DataType::Decimal128(12, 4), true)],
429+
std::collections::HashMap::new(),
430+
)
431+
.unwrap(),
432+
),
433+
}));
434+
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
435+
let plan = rule.optimize(&plan, &mut config)?;
436+
assert_eq!(
437+
"Projection: CAST(#a AS Decimal128(24, 4)) IN ([Decimal128(Some(10000),24,4), Decimal128(Some(40000),24,4), Decimal128(Some(80000),24,4)])\n EmptyRelation",
438+
&format!("{:?}", plan)
439+
);
440+
Ok(())
441+
}
442+
351443
fn empty() -> Arc<LogicalPlan> {
352444
Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
353445
produce_one_row: false,

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,17 @@ pub fn in_list(
960960
negated: &bool,
961961
schema: &Schema,
962962
) -> Result<Arc<dyn PhysicalExpr>> {
963+
// check the data type
964+
let expr_data_type = expr.data_type(schema)?;
965+
for list_expr in list.iter() {
966+
let list_expr_data_type = list_expr.data_type(schema)?;
967+
if !expr_data_type.eq(&list_expr_data_type) {
968+
return Err(DataFusionError::Internal(format!(
969+
"The data type inlist should be same, the value type is {}, one of list expr type is {}",
970+
expr_data_type, list_expr_data_type
971+
)));
972+
}
973+
}
963974
Ok(Arc::new(InListExpr::new(expr, list, *negated, schema)))
964975
}
965976

@@ -969,9 +980,54 @@ mod tests {
969980

970981
use super::*;
971982
use crate::expressions;
972-
use crate::expressions::{col, lit};
973-
use crate::planner::in_list_cast;
983+
use crate::expressions::{col, lit, try_cast};
974984
use datafusion_common::Result;
985+
use datafusion_expr::binary_rule::comparison_coercion;
986+
987+
type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>);
988+
989+
// Try to do the type coercion for list physical expr.
990+
// It's just used in the test
991+
fn in_list_cast(
992+
expr: Arc<dyn PhysicalExpr>,
993+
list: Vec<Arc<dyn PhysicalExpr>>,
994+
input_schema: &Schema,
995+
) -> Result<InListCastResult> {
996+
let expr_type = &expr.data_type(input_schema)?;
997+
let list_types: Vec<DataType> = list
998+
.iter()
999+
.map(|list_expr| list_expr.data_type(input_schema).unwrap())
1000+
.collect();
1001+
let result_type = get_coerce_type(expr_type, &list_types);
1002+
match result_type {
1003+
None => Err(DataFusionError::Plan(format!(
1004+
"Can not find compatible types to compare {:?} with {:?}",
1005+
expr_type, list_types
1006+
))),
1007+
Some(data_type) => {
1008+
// find the coerced type
1009+
let cast_expr = try_cast(expr, input_schema, data_type.clone())?;
1010+
let cast_list_expr = list
1011+
.into_iter()
1012+
.map(|list_expr| {
1013+
try_cast(list_expr, input_schema, data_type.clone()).unwrap()
1014+
})
1015+
.collect();
1016+
Ok((cast_expr, cast_list_expr))
1017+
}
1018+
}
1019+
}
1020+
1021+
// Attempts to coerce the types of `list_type` to be comparable with the
1022+
// `expr_type`
1023+
fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option<DataType> {
1024+
list_type
1025+
.iter()
1026+
.fold(Some(expr_type.clone()), |left, right_type| match left {
1027+
None => None,
1028+
Some(left_type) => comparison_coercion(&left_type, right_type),
1029+
})
1030+
}
9751031

9761032
// applies the in_list expr to an input batch and list
9771033
macro_rules! in_list {

datafusion/physical-expr/src/planner.rs

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::expressions::try_cast;
1918
use crate::var_provider::is_system_variables;
2019
use crate::{
2120
execution_props::ExecutionProps,
@@ -28,7 +27,6 @@ use crate::{
2827
};
2928
use arrow::datatypes::{DataType, Schema};
3029
use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
31-
use datafusion_expr::binary_rule::comparison_coercion;
3230
use datafusion_expr::{binary_expr, Expr, Operator};
3331
use std::sync::Arc;
3432

@@ -410,10 +408,7 @@ pub fn create_physical_expr(
410408
)
411409
})
412410
.collect::<Result<Vec<_>>>()?;
413-
414-
let (cast_expr, cast_list_exprs) =
415-
in_list_cast(value_expr, list_exprs, input_schema)?;
416-
expressions::in_list(cast_expr, cast_list_exprs, negated, input_schema)
411+
expressions::in_list(value_expr, list_exprs, negated, input_schema)
417412
}
418413
},
419414
other => Err(DataFusionError::NotImplemented(format!(
@@ -422,50 +417,3 @@ pub fn create_physical_expr(
422417
))),
423418
}
424419
}
425-
426-
type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>);
427-
428-
pub(crate) fn in_list_cast(
429-
expr: Arc<dyn PhysicalExpr>,
430-
list: Vec<Arc<dyn PhysicalExpr>>,
431-
input_schema: &Schema,
432-
) -> Result<InListCastResult> {
433-
let expr_type = &expr.data_type(input_schema)?;
434-
let list_types: Vec<DataType> = list
435-
.iter()
436-
.map(|list_expr| list_expr.data_type(input_schema).unwrap())
437-
.collect();
438-
let result_type = get_coerce_type(expr_type, &list_types);
439-
match result_type {
440-
None => Err(DataFusionError::Plan(format!(
441-
"Can not find compatible types to compare {:?} with {:?}",
442-
expr_type, list_types
443-
))),
444-
Some(data_type) => {
445-
// find the coerced type
446-
let cast_expr = try_cast(expr, input_schema, data_type.clone())?;
447-
let cast_list_expr = list
448-
.into_iter()
449-
.map(|list_expr| {
450-
try_cast(list_expr, input_schema, data_type.clone()).unwrap()
451-
})
452-
.collect();
453-
Ok((cast_expr, cast_list_expr))
454-
}
455-
}
456-
}
457-
458-
/// Attempts to coerce the types of `list_type` to be comparable with the
459-
/// `expr_type`
460-
fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option<DataType> {
461-
// get the equal coerced data type
462-
list_type
463-
.iter()
464-
.fold(Some(expr_type.clone()), |left, right_type| {
465-
match left {
466-
None => None,
467-
// TODO refactor a framework to do the data type coercion
468-
Some(left_type) => comparison_coercion(&left_type, right_type),
469-
}
470-
})
471-
}

0 commit comments

Comments
 (0)