@@ -19,14 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
19
19
20
20
import org .apache .spark .sql .catalyst .dsl .expressions ._
21
21
import org .apache .spark .sql .catalyst .dsl .plans ._
22
- import org .apache .spark .sql .catalyst .expressions .{GetStructField , MultiScalarSubquery , ScalarSubquery }
22
+ import org .apache .spark .sql .catalyst .expressions .{CreateStruct , GetStructField , ScalarSubquery }
23
23
import org .apache .spark .sql .catalyst .expressions .aggregate .{CollectList , CollectSet }
24
24
import org .apache .spark .sql .catalyst .plans ._
25
25
import org .apache .spark .sql .catalyst .plans .logical ._
26
26
import org .apache .spark .sql .catalyst .rules ._
27
27
28
28
class MergeScalarSubqueriesSuite extends PlanTest {
29
-
30
29
private object Optimize extends RuleExecutor [LogicalPlan ] {
31
30
val batches = Batch (" MergeScalarSubqueries" , Once , MergeScalarSubqueries ) :: Nil
32
31
}
@@ -35,82 +34,81 @@ class MergeScalarSubqueriesSuite extends PlanTest {
35
34
36
35
test(" Simple non-correlated scalar subquery merge" ) {
37
36
val subquery1 = testRelation
38
- .groupBy(' b )(max(' a ))
37
+ .groupBy(' b )(max(' a ).as( " max_a " ) )
39
38
val subquery2 = testRelation
40
- .groupBy(' b )(sum(' a ))
39
+ .groupBy(' b )(sum(' a ).as( " sum_a " ) )
41
40
val originalQuery = testRelation
42
41
.select(ScalarSubquery (subquery1), ScalarSubquery (subquery2))
43
42
44
43
val multiSubquery = testRelation
45
- .groupBy(' b )(max(' a ), sum(' a )).analyze
44
+ .groupBy(' b )(max(' a ).as(" max_a" ), sum(' a ).as(" sum_a" ))
45
+ .select(CreateStruct (Seq (' max_a , ' sum_a )).as(" mergedValue" ))
46
46
val correctAnswer = testRelation
47
- .select(GetStructField (MultiScalarSubquery (multiSubquery), 0 ).as(" scalarsubquery()" ),
48
- GetStructField (MultiScalarSubquery (multiSubquery), 1 ).as(" scalarsubquery()" ))
47
+ .select(GetStructField (ScalarSubquery (multiSubquery), 0 ).as(" scalarsubquery()" ),
48
+ GetStructField (ScalarSubquery (multiSubquery), 1 ).as(" scalarsubquery()" ))
49
49
50
- // checkAnalysis is disabled because `Analizer` is not prepared for `MultiScalarSubquery` nodes
51
- // as only `Optimizer` can insert such a node to the plan
52
- comparePlans(Optimize .execute(originalQuery.analyze), correctAnswer, false )
50
+ comparePlans(Optimize .execute(originalQuery.analyze), correctAnswer.analyze)
53
51
}
54
52
55
53
test(" Aggregate and group expression merge" ) {
56
54
val subquery1 = testRelation
57
- .groupBy(' b )(max(' a ))
55
+ .groupBy(' b )(max(' a ).as( " max_a " ) )
58
56
val subquery2 = testRelation
59
57
.groupBy(' b )(' b )
60
58
val originalQuery = testRelation
61
59
.select(ScalarSubquery (subquery1), ScalarSubquery (subquery2))
62
60
63
61
val multiSubquery = testRelation
64
- .groupBy(' b )(max(' a ), ' b ).analyze
62
+ .groupBy(' b )(max(' a ).as(" max_a" ), ' b )
63
+ .select(CreateStruct (Seq (' max_a , ' b )).as(" mergedValue" ))
65
64
val correctAnswer = testRelation
66
- .select(GetStructField (MultiScalarSubquery (multiSubquery), 0 ).as(" scalarsubquery()" ),
67
- GetStructField (MultiScalarSubquery (multiSubquery), 1 ).as(" scalarsubquery()" ))
65
+ .select(GetStructField (ScalarSubquery (multiSubquery), 0 ).as(" scalarsubquery()" ),
66
+ GetStructField (ScalarSubquery (multiSubquery), 1 ).as(" scalarsubquery()" ))
68
67
69
- // checkAnalysis is disabled because `Analizer` is not prepared for `MultiScalarSubquery` nodes
70
- // as only `Optimizer` can insert such a node to the plan
71
- comparePlans(Optimize .execute(originalQuery.analyze), correctAnswer, false )
68
+ comparePlans(Optimize .execute(originalQuery.analyze), correctAnswer.analyze)
72
69
}
73
70
74
71
test(" Do not merge different aggregate implementations" ) {
75
72
// supports HashAggregate
76
73
val subquery1 = testRelation
77
- .groupBy(' b )(max(' a ))
74
+ .groupBy(' b )(max(' a ).as( " max_a " ) )
78
75
val subquery2 = testRelation
79
- .groupBy(' b )(min(' a ))
76
+ .groupBy(' b )(min(' a ).as( " min_a " ) )
80
77
81
78
// supports ObjectHashAggregate
82
79
val subquery3 = testRelation
83
- .groupBy(' b )(CollectList (' a ).toAggregateExpression(isDistinct = false ))
80
+ .groupBy(' b )(CollectList (' a ).toAggregateExpression(isDistinct = false ).as( " collectlist_a " ) )
84
81
val subquery4 = testRelation
85
- .groupBy(' b )(CollectSet (' a ).toAggregateExpression(isDistinct = false ))
82
+ .groupBy(' b )(CollectSet (' a ).toAggregateExpression(isDistinct = false ).as( " collectset_a " ) )
86
83
87
84
// supports SortAggregate
88
85
val subquery5 = testRelation
89
- .groupBy(' b )(max(' c ))
86
+ .groupBy(' b )(max(' c ).as( " max_c " ) )
90
87
val subquery6 = testRelation
91
- .groupBy(' b )(min(' c ))
88
+ .groupBy(' b )(min(' c ).as( " min_c " ) )
92
89
93
90
val originalQuery = testRelation
94
91
.select(ScalarSubquery (subquery1), ScalarSubquery (subquery2), ScalarSubquery (subquery3),
95
92
ScalarSubquery (subquery4), ScalarSubquery (subquery5), ScalarSubquery (subquery6))
96
93
97
94
val hashAggregates = testRelation
98
- .groupBy(' b )(max(' a ), min(' a )).analyze
95
+ .groupBy(' b )(max(' a ).as(" max_a" ), min(' a ).as(" min_a" ))
96
+ .select(CreateStruct (Seq (' max_a , ' min_a )).as(" mergedValue" ))
99
97
val objectHashAggregates = testRelation
100
- .groupBy(' b )(CollectList (' a ).toAggregateExpression(isDistinct = false ),
101
- CollectSet (' a ).toAggregateExpression(isDistinct = false )).analyze
98
+ .groupBy(' b )(CollectList (' a ).toAggregateExpression(isDistinct = false ).as(" collectlist_a" ),
99
+ CollectSet (' a ).toAggregateExpression(isDistinct = false ).as(" collectset_a" ))
100
+ .select(CreateStruct (Seq (' collectlist_a , ' collectset_a )).as(" mergedValue" ))
102
101
val sortAggregates = testRelation
103
- .groupBy(' b )(max(' c ), min(' c )).analyze
102
+ .groupBy(' b )(max(' c ).as(" max_c" ), min(' c ).as(" min_c" ))
103
+ .select(CreateStruct (Seq (' max_c , ' min_c )).as(" mergedValue" ))
104
104
val correctAnswer = testRelation
105
- .select(GetStructField (MultiScalarSubquery (hashAggregates), 0 ).as(" scalarsubquery()" ),
106
- GetStructField (MultiScalarSubquery (hashAggregates), 1 ).as(" scalarsubquery()" ),
107
- GetStructField (MultiScalarSubquery (objectHashAggregates), 0 ).as(" scalarsubquery()" ),
108
- GetStructField (MultiScalarSubquery (objectHashAggregates), 1 ).as(" scalarsubquery()" ),
109
- GetStructField (MultiScalarSubquery (sortAggregates), 0 ).as(" scalarsubquery()" ),
110
- GetStructField (MultiScalarSubquery (sortAggregates), 1 ).as(" scalarsubquery()" ))
111
-
112
- // checkAnalysis is disabled because `Analizer` is not prepared for `MultiScalarSubquery` nodes
113
- // as only `Optimizer` can insert such a node to the plan
114
- comparePlans(Optimize .execute(originalQuery.analyze), correctAnswer, false )
105
+ .select(GetStructField (ScalarSubquery (hashAggregates), 0 ).as(" scalarsubquery()" ),
106
+ GetStructField (ScalarSubquery (hashAggregates), 1 ).as(" scalarsubquery()" ),
107
+ GetStructField (ScalarSubquery (objectHashAggregates), 0 ).as(" scalarsubquery()" ),
108
+ GetStructField (ScalarSubquery (objectHashAggregates), 1 ).as(" scalarsubquery()" ),
109
+ GetStructField (ScalarSubquery (sortAggregates), 0 ).as(" scalarsubquery()" ),
110
+ GetStructField (ScalarSubquery (sortAggregates), 1 ).as(" scalarsubquery()" ))
111
+
112
+ comparePlans(Optimize .execute(originalQuery.analyze), correctAnswer.analyze)
115
113
}
116
114
}
0 commit comments