@@ -37,17 +37,18 @@ class RewriteSubquerySuite extends PlanTest {
37
37
InferFiltersFromConstraints ,
38
38
PushDownPredicate ,
39
39
CollapseProject ,
40
+ CombineFilters ,
40
41
RemoveRedundantProject ) :: Nil
41
42
}
42
43
44
+ val relation = LocalRelation (' a .int, ' b .int)
45
+ val relInSubquery = LocalRelation (' x .int, ' y .int, ' z .int)
46
+
43
47
test(" Column pruning after rewriting predicate subquery" ) {
44
48
withSQLConf(SQLConf .CONSTRAINT_PROPAGATION_ENABLED .key -> " false" ) {
45
- val relation = LocalRelation (' a .int, ' b .int)
46
- val relInSubquery = LocalRelation (' x .int, ' y .int, ' z .int)
47
-
48
49
val query = relation.where(' a .in(ListQuery (relInSubquery.select(' x )))).select(' a )
49
-
50
50
val optimized = Optimize .execute(query.analyze)
51
+
51
52
val correctAnswer = relation
52
53
.select(' a )
53
54
.join(relInSubquery.select(' x ), LeftSemi , Some (' a === ' x ))
@@ -59,12 +60,9 @@ class RewriteSubquerySuite extends PlanTest {
59
60
60
61
test(" Infer filters and push down predicate after rewriting predicate subquery" ) {
61
62
withSQLConf(SQLConf .CONSTRAINT_PROPAGATION_ENABLED .key -> " true" ) {
62
- val relation = LocalRelation (' a .int, ' b .int)
63
- val relInSubquery = LocalRelation (' x .int, ' y .int, ' z .int)
64
-
65
63
val query = relation.where(' a .in(ListQuery (relInSubquery.select(' x )))).select(' a )
66
-
67
64
val optimized = Optimize .execute(query.analyze)
65
+
68
66
val correctAnswer = relation
69
67
.where(IsNotNull (' a )).select(' a )
70
68
.join(relInSubquery.where(IsNotNull (' x )).select(' x ), LeftSemi , Some (' a === ' x ))
@@ -74,4 +72,17 @@ class RewriteSubquerySuite extends PlanTest {
74
72
}
75
73
}
76
74
75
+ test(" combine filters after rewriting predicate subquery" ) {
76
+ val query = relation.where(' a .in(ListQuery (relInSubquery.select(' x ).where(' y > 1 )))).select(' a )
77
+ val optimized = Optimize .execute(query.analyze)
78
+
79
+ val correctAnswer = relation
80
+ .where(IsNotNull (' a )).select(' a )
81
+ .join(relInSubquery.where(IsNotNull (' x ) && IsNotNull (' y ) && ' y > 1 ).select(' x ),
82
+ LeftSemi , Some (' a === ' x ))
83
+ .analyze
84
+
85
+ comparePlans(optimized, correctAnswer)
86
+ }
87
+
77
88
}
0 commit comments