Skip to content

Commit 41effc4

Browse files
authored
Encapsulate ProjectionMapping as a struct (#8033)
1 parent c2e7680 commit 41effc4

File tree

5 files changed

+309
-283
lines changed

5 files changed

+309
-283
lines changed

datafusion/physical-expr/src/equivalence.rs

Lines changed: 300 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,65 @@ use indexmap::IndexMap;
4242
pub type EquivalenceClass = Vec<Arc<dyn PhysicalExpr>>;
4343

4444
/// Stores the mapping between source expressions and target expressions for a
45-
/// projection. Indices in the vector corresponds to the indices after projection.
46-
pub type ProjectionMapping = Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>;
45+
/// projection.
46+
#[derive(Debug, Clone)]
47+
pub struct ProjectionMapping {
48+
/// `(source expression)` --> `(target expression)`
49+
/// Indices in the vector corresponds to the indices after projection.
50+
inner: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
51+
}
52+
53+
impl ProjectionMapping {
54+
/// Constructs the mapping between a projection's input and output
55+
/// expressions.
56+
///
57+
/// For example, given the input projection expressions (`a+b`, `c+d`)
58+
/// and an output schema with two columns `"c+d"` and `"a+b"`
59+
/// the projection mapping would be
60+
/// ```text
61+
/// [0]: (c+d, col("c+d"))
62+
/// [1]: (a+b, col("a+b"))
63+
/// ```
64+
/// where `col("c+d")` means the column named "c+d".
65+
pub fn try_new(
66+
expr: &[(Arc<dyn PhysicalExpr>, String)],
67+
input_schema: &SchemaRef,
68+
) -> Result<Self> {
69+
// Construct a map from the input expressions to the output expression of the projection:
70+
let mut inner = vec![];
71+
for (expr_idx, (expression, name)) in expr.iter().enumerate() {
72+
let target_expr = Arc::new(Column::new(name, expr_idx)) as _;
73+
74+
let source_expr = expression.clone().transform_down(&|e| match e
75+
.as_any()
76+
.downcast_ref::<Column>(
77+
) {
78+
Some(col) => {
79+
// Sometimes, expression and its name in the input_schema doesn't match.
80+
// This can cause problems. Hence in here we make sure that expression name
81+
// matches with the name in the inout_schema.
82+
// Conceptually, source_expr and expression should be same.
83+
let idx = col.index();
84+
let matching_input_field = input_schema.field(idx);
85+
let matching_input_column =
86+
Column::new(matching_input_field.name(), idx);
87+
Ok(Transformed::Yes(Arc::new(matching_input_column)))
88+
}
89+
None => Ok(Transformed::No(e)),
90+
})?;
91+
92+
inner.push((source_expr, target_expr));
93+
}
94+
Ok(Self { inner })
95+
}
96+
97+
/// Iterate over pairs of (source, target) expressions
98+
pub fn iter(
99+
&self,
100+
) -> impl Iterator<Item = &(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> + '_ {
101+
self.inner.iter()
102+
}
103+
}
47104

48105
/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each
49106
/// class represents a distinct equivalence class in a relation.
@@ -329,7 +386,7 @@ impl EquivalenceGroup {
329386
// once `Arc<dyn PhysicalExpr>` can be stored in `HashMap`.
330387
// See issue: https://github.com/apache/arrow-datafusion/issues/8027
331388
let mut new_classes = vec![];
332-
for (source, target) in mapping {
389+
for (source, target) in mapping.iter() {
333390
if new_classes.is_empty() {
334391
new_classes.push((source, vec![target.clone()]));
335392
}
@@ -980,7 +1037,7 @@ impl EquivalenceProperties {
9801037
.iter()
9811038
.filter_map(|order| self.eq_group.project_ordering(projection_mapping, order))
9821039
.collect::<Vec<_>>();
983-
for (source, target) in projection_mapping {
1040+
for (source, target) in projection_mapping.iter() {
9841041
let expr_ordering = ExprOrdering::new(source.clone())
9851042
.transform_up(&|expr| update_ordering(expr, self))
9861043
.unwrap();
@@ -1185,6 +1242,7 @@ fn update_ordering(
11851242

11861243
#[cfg(test)]
11871244
mod tests {
1245+
use std::ops::Not;
11881246
use std::sync::Arc;
11891247

11901248
use super::*;
@@ -1437,12 +1495,14 @@ mod tests {
14371495
let col_a2 = &col("a2", &out_schema)?;
14381496
let col_a3 = &col("a3", &out_schema)?;
14391497
let col_a4 = &col("a4", &out_schema)?;
1440-
let projection_mapping = vec![
1441-
(col_a.clone(), col_a1.clone()),
1442-
(col_a.clone(), col_a2.clone()),
1443-
(col_a.clone(), col_a3.clone()),
1444-
(col_a.clone(), col_a4.clone()),
1445-
];
1498+
let projection_mapping = ProjectionMapping {
1499+
inner: vec![
1500+
(col_a.clone(), col_a1.clone()),
1501+
(col_a.clone(), col_a2.clone()),
1502+
(col_a.clone(), col_a3.clone()),
1503+
(col_a.clone(), col_a4.clone()),
1504+
],
1505+
};
14461506
let out_properties = input_properties.project(&projection_mapping, out_schema);
14471507

14481508
// At the output a1=a2=a3=a4
@@ -2565,4 +2625,234 @@ mod tests {
25652625

25662626
Ok(())
25672627
}
2628+
2629+
#[test]
2630+
fn test_get_indices_of_matching_sort_exprs_with_order_eq() -> Result<()> {
2631+
let sort_options = SortOptions::default();
2632+
let sort_options_not = SortOptions::default().not();
2633+
2634+
let schema = Schema::new(vec![
2635+
Field::new("a", DataType::Int32, true),
2636+
Field::new("b", DataType::Int32, true),
2637+
]);
2638+
let col_a = &col("a", &schema)?;
2639+
let col_b = &col("b", &schema)?;
2640+
let required_columns = [col_b.clone(), col_a.clone()];
2641+
let mut eq_properties = EquivalenceProperties::new(Arc::new(schema));
2642+
eq_properties.add_new_orderings([vec![
2643+
PhysicalSortExpr {
2644+
expr: Arc::new(Column::new("b", 1)),
2645+
options: sort_options_not,
2646+
},
2647+
PhysicalSortExpr {
2648+
expr: Arc::new(Column::new("a", 0)),
2649+
options: sort_options,
2650+
},
2651+
]]);
2652+
let (result, idxs) = eq_properties.find_longest_permutation(&required_columns);
2653+
assert_eq!(idxs, vec![0, 1]);
2654+
assert_eq!(
2655+
result,
2656+
vec![
2657+
PhysicalSortExpr {
2658+
expr: col_b.clone(),
2659+
options: sort_options_not
2660+
},
2661+
PhysicalSortExpr {
2662+
expr: col_a.clone(),
2663+
options: sort_options
2664+
}
2665+
]
2666+
);
2667+
2668+
let schema = Schema::new(vec![
2669+
Field::new("a", DataType::Int32, true),
2670+
Field::new("b", DataType::Int32, true),
2671+
Field::new("c", DataType::Int32, true),
2672+
]);
2673+
let col_a = &col("a", &schema)?;
2674+
let col_b = &col("b", &schema)?;
2675+
let required_columns = [col_b.clone(), col_a.clone()];
2676+
let mut eq_properties = EquivalenceProperties::new(Arc::new(schema));
2677+
eq_properties.add_new_orderings([
2678+
vec![PhysicalSortExpr {
2679+
expr: Arc::new(Column::new("c", 2)),
2680+
options: sort_options,
2681+
}],
2682+
vec![
2683+
PhysicalSortExpr {
2684+
expr: Arc::new(Column::new("b", 1)),
2685+
options: sort_options_not,
2686+
},
2687+
PhysicalSortExpr {
2688+
expr: Arc::new(Column::new("a", 0)),
2689+
options: sort_options,
2690+
},
2691+
],
2692+
]);
2693+
let (result, idxs) = eq_properties.find_longest_permutation(&required_columns);
2694+
assert_eq!(idxs, vec![0, 1]);
2695+
assert_eq!(
2696+
result,
2697+
vec![
2698+
PhysicalSortExpr {
2699+
expr: col_b.clone(),
2700+
options: sort_options_not
2701+
},
2702+
PhysicalSortExpr {
2703+
expr: col_a.clone(),
2704+
options: sort_options
2705+
}
2706+
]
2707+
);
2708+
2709+
let required_columns = [
2710+
Arc::new(Column::new("b", 1)) as _,
2711+
Arc::new(Column::new("a", 0)) as _,
2712+
];
2713+
let schema = Schema::new(vec![
2714+
Field::new("a", DataType::Int32, true),
2715+
Field::new("b", DataType::Int32, true),
2716+
Field::new("c", DataType::Int32, true),
2717+
]);
2718+
let mut eq_properties = EquivalenceProperties::new(Arc::new(schema));
2719+
2720+
// not satisfied orders
2721+
eq_properties.add_new_orderings([vec![
2722+
PhysicalSortExpr {
2723+
expr: Arc::new(Column::new("b", 1)),
2724+
options: sort_options_not,
2725+
},
2726+
PhysicalSortExpr {
2727+
expr: Arc::new(Column::new("c", 2)),
2728+
options: sort_options,
2729+
},
2730+
PhysicalSortExpr {
2731+
expr: Arc::new(Column::new("a", 0)),
2732+
options: sort_options,
2733+
},
2734+
]]);
2735+
let (_, idxs) = eq_properties.find_longest_permutation(&required_columns);
2736+
assert_eq!(idxs, vec![0]);
2737+
2738+
Ok(())
2739+
}
2740+
2741+
#[test]
2742+
fn test_normalize_ordering_equivalence_classes() -> Result<()> {
2743+
let sort_options = SortOptions::default();
2744+
2745+
let schema = Schema::new(vec![
2746+
Field::new("a", DataType::Int32, true),
2747+
Field::new("b", DataType::Int32, true),
2748+
Field::new("c", DataType::Int32, true),
2749+
]);
2750+
let col_a_expr = col("a", &schema)?;
2751+
let col_b_expr = col("b", &schema)?;
2752+
let col_c_expr = col("c", &schema)?;
2753+
let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone()));
2754+
2755+
eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr);
2756+
let others = vec![
2757+
vec![PhysicalSortExpr {
2758+
expr: col_b_expr.clone(),
2759+
options: sort_options,
2760+
}],
2761+
vec![PhysicalSortExpr {
2762+
expr: col_c_expr.clone(),
2763+
options: sort_options,
2764+
}],
2765+
];
2766+
eq_properties.add_new_orderings(others);
2767+
2768+
let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema));
2769+
expected_eqs.add_new_orderings([
2770+
vec![PhysicalSortExpr {
2771+
expr: col_b_expr.clone(),
2772+
options: sort_options,
2773+
}],
2774+
vec![PhysicalSortExpr {
2775+
expr: col_c_expr.clone(),
2776+
options: sort_options,
2777+
}],
2778+
]);
2779+
2780+
let oeq_class = eq_properties.oeq_class().clone();
2781+
let expected = expected_eqs.oeq_class();
2782+
assert!(oeq_class.eq(expected));
2783+
2784+
Ok(())
2785+
}
2786+
2787+
#[test]
2788+
fn project_empty_output_ordering() -> Result<()> {
2789+
let schema = Schema::new(vec![
2790+
Field::new("a", DataType::Int32, true),
2791+
Field::new("b", DataType::Int32, true),
2792+
Field::new("c", DataType::Int32, true),
2793+
]);
2794+
let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone()));
2795+
let ordering = vec![PhysicalSortExpr {
2796+
expr: Arc::new(Column::new("b", 1)),
2797+
options: SortOptions::default(),
2798+
}];
2799+
eq_properties.add_new_orderings([ordering]);
2800+
let projection_mapping = ProjectionMapping {
2801+
inner: vec![
2802+
(
2803+
Arc::new(Column::new("b", 1)) as _,
2804+
Arc::new(Column::new("b_new", 0)) as _,
2805+
),
2806+
(
2807+
Arc::new(Column::new("a", 0)) as _,
2808+
Arc::new(Column::new("a_new", 1)) as _,
2809+
),
2810+
],
2811+
};
2812+
let projection_schema = Arc::new(Schema::new(vec![
2813+
Field::new("b_new", DataType::Int32, true),
2814+
Field::new("a_new", DataType::Int32, true),
2815+
]));
2816+
let orderings = eq_properties
2817+
.project(&projection_mapping, projection_schema)
2818+
.oeq_class()
2819+
.output_ordering()
2820+
.unwrap_or_default();
2821+
2822+
assert_eq!(
2823+
vec![PhysicalSortExpr {
2824+
expr: Arc::new(Column::new("b_new", 0)),
2825+
options: SortOptions::default(),
2826+
}],
2827+
orderings
2828+
);
2829+
2830+
let schema = Schema::new(vec![
2831+
Field::new("a", DataType::Int32, true),
2832+
Field::new("b", DataType::Int32, true),
2833+
Field::new("c", DataType::Int32, true),
2834+
]);
2835+
let eq_properties = EquivalenceProperties::new(Arc::new(schema));
2836+
let projection_mapping = ProjectionMapping {
2837+
inner: vec![
2838+
(
2839+
Arc::new(Column::new("c", 2)) as _,
2840+
Arc::new(Column::new("c_new", 0)) as _,
2841+
),
2842+
(
2843+
Arc::new(Column::new("b", 1)) as _,
2844+
Arc::new(Column::new("b_new", 1)) as _,
2845+
),
2846+
],
2847+
};
2848+
let projection_schema = Arc::new(Schema::new(vec![
2849+
Field::new("c_new", DataType::Int32, true),
2850+
Field::new("b_new", DataType::Int32, true),
2851+
]));
2852+
let projected = eq_properties.project(&projection_mapping, projection_schema);
2853+
// After projection there is no ordering.
2854+
assert!(projected.oeq_class().output_ordering().is_none());
2855+
2856+
Ok(())
2857+
}
25682858
}

0 commit comments

Comments
 (0)