17
17
18
18
package org .apache .spark .sql .catalyst .optimizer
19
19
20
+ import org .apache .spark .sql .catalyst .SimpleCatalystConf
21
+ import org .apache .spark .sql .catalyst .analysis .{Analyzer , EmptyFunctionRegistry }
22
+ import org .apache .spark .sql .catalyst .catalog .{InMemoryCatalog , SessionCatalog }
20
23
import org .apache .spark .sql .catalyst .dsl .expressions ._
21
24
import org .apache .spark .sql .catalyst .dsl .plans ._
22
25
import org .apache .spark .sql .catalyst .expressions .Literal
@@ -25,10 +28,14 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
25
28
import org .apache .spark .sql .catalyst .rules .RuleExecutor
26
29
27
30
class AggregateOptimizeSuite extends PlanTest {
31
// Analysis fixtures shared by the tests in this suite.

// Configuration with case-insensitive analysis.
val conf: SimpleCatalystConf = new SimpleCatalystConf(caseSensitiveAnalysis = false)

// Session catalog built on a fresh in-memory catalog and the empty function registry.
val catalog: SessionCatalog = {
  val externalCatalog = new InMemoryCatalog
  new SessionCatalog(externalCatalog, EmptyFunctionRegistry, conf)
}

// Analyzer wired to the catalog and configuration above.
val analyzer: Analyzer = new Analyzer(catalog, conf)
28
34
29
35
/**
 * Rule executor exercised by this suite: a single batch that applies
 * RemoveLiteralFromGroupExpressions and RemoveRepetitionFromGroupExpressions
 * under the FixedPoint(100) strategy.
 */
object Optimize extends RuleExecutor[LogicalPlan] {
  val batches = List(
    Batch(" Aggregate", FixedPoint(100),
      RemoveLiteralFromGroupExpressions,
      RemoveRepetitionFromGroupExpressions))
}
33
40
34
41
test(" remove literals in grouping expression" ) {
@@ -42,4 +49,15 @@ class AggregateOptimizeSuite extends PlanTest {
42
49
43
50
comparePlans(optimized, correctAnswer)
44
51
}
52
+
53
// Checks that grouping expressions repeating an earlier one are dropped by Optimize.
test(" remove repetition in grouping expression") {
  val relation = LocalRelation('a.int, 'b.int, 'c.int)

  // The last two grouping expressions repeat the first two with the operands
  // swapped and the attribute names upper-cased; the suite's analyzer is
  // configured with caseSensitiveAnalysis = false.
  val groupedWithRepeats =
    relation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
  val analyzed = analyzer.execute(groupedWithRepeats)
  val actual = Optimize.execute(analyzed)

  // Expected plan: only the first two grouping expressions remain.
  val expected = analyzer.execute(relation.groupBy('a + 1, 'b + 2)(sum('c)))

  comparePlans(actual, expected)
}
45
63
}
0 commit comments