17
17
18
18
package org .apache .spark .sql .catalyst .optimizer
19
19
20
- import org .scalatest .Matchers ._
21
-
22
20
import org .apache .spark .api .python .PythonEvalType
23
21
import org .apache .spark .sql .AnalysisException
24
22
import org .apache .spark .sql .catalyst .dsl .expressions ._
@@ -28,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
28
26
import org .apache .spark .sql .catalyst .plans .logical .{LocalRelation , LogicalPlan }
29
27
import org .apache .spark .sql .catalyst .rules .RuleExecutor
30
28
import org .apache .spark .sql .internal .SQLConf ._
31
- import org .apache .spark .sql .types .BooleanType
29
+ import org .apache .spark .sql .types .{ BooleanType , IntegerType }
32
30
33
31
class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
34
32
@@ -40,13 +38,29 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
40
38
CheckCartesianProducts ) :: Nil
41
39
}
42
40
43
- val testRelationLeft = LocalRelation (' a .int, ' b .int)
44
- val testRelationRight = LocalRelation (' c .int, ' d .int)
41
+ val attrA = ' a .int
42
+ val attrB = ' b .int
43
+ val attrC = ' c .int
44
+ val attrD = ' d .int
45
+
46
+ val testRelationLeft = LocalRelation (attrA, attrB)
47
+ val testRelationRight = LocalRelation (attrC, attrD)
48
+
49
+ // This join condition refers to attributes from 2 tables, but the PythonUDF inside it only
50
+ // refers to attributes from one side.
51
+ val evaluableJoinCond = {
52
+ val pythonUDF = PythonUDF (" evaluable" , null ,
53
+ IntegerType ,
54
+ Seq (attrA),
55
+ PythonEvalType .SQL_BATCHED_UDF ,
56
+ udfDeterministic = true )
57
+ pythonUDF === attrC
58
+ }
45
59
46
- // Dummy python UDF for testing. Unable to execute .
47
- val pythonUDF = PythonUDF (" pythonUDF " , null ,
60
+ // This join condition is a PythonUDF which refers to attributes from 2 tables .
61
+ val unevaluableJoinCond = PythonUDF (" unevaluable " , null ,
48
62
BooleanType ,
49
- Seq .empty ,
63
+ Seq (attrA, attrC) ,
50
64
PythonEvalType .SQL_BATCHED_UDF ,
51
65
udfDeterministic = true )
52
66
@@ -66,62 +80,76 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
66
80
}
67
81
}
68
82
69
- test(" inner join condition with python udf only " ) {
70
- val query = testRelationLeft.join(
83
+ test(" inner join condition with python udf" ) {
84
+ val query1 = testRelationLeft.join(
71
85
testRelationRight,
72
86
joinType = Inner ,
73
- condition = Some (pythonUDF ))
74
- val expected = testRelationLeft.join(
87
+ condition = Some (unevaluableJoinCond ))
88
+ val expected1 = testRelationLeft.join(
75
89
testRelationRight,
76
90
joinType = Inner ,
77
- condition = None ).where(pythonUDF).analyze
78
- comparePlanWithCrossJoinEnable(query, expected)
91
+ condition = None ).where(unevaluableJoinCond).analyze
92
+ comparePlanWithCrossJoinEnable(query1, expected1)
93
+
94
+ // evaluable PythonUDF will not be touched
95
+ val query2 = testRelationLeft.join(
96
+ testRelationRight,
97
+ joinType = Inner ,
98
+ condition = Some (evaluableJoinCond))
99
+ comparePlans(Optimize .execute(query2), query2)
79
100
}
80
101
81
- test(" left semi join condition with python udf only " ) {
82
- val query = testRelationLeft.join(
102
+ test(" left semi join condition with python udf" ) {
103
+ val query1 = testRelationLeft.join(
83
104
testRelationRight,
84
105
joinType = LeftSemi ,
85
- condition = Some (pythonUDF ))
86
- val expected = testRelationLeft.join(
106
+ condition = Some (unevaluableJoinCond ))
107
+ val expected1 = testRelationLeft.join(
87
108
testRelationRight,
88
109
joinType = Inner ,
89
- condition = None ).where(pythonUDF).select(' a , ' b ).analyze
90
- comparePlanWithCrossJoinEnable(query, expected)
110
+ condition = None ).where(unevaluableJoinCond).select(' a , ' b ).analyze
111
+ comparePlanWithCrossJoinEnable(query1, expected1)
112
+
113
+ // evaluable PythonUDF will not be touched
114
+ val query2 = testRelationLeft.join(
115
+ testRelationRight,
116
+ joinType = LeftSemi ,
117
+ condition = Some (evaluableJoinCond))
118
+ comparePlans(Optimize .execute(query2), query2)
91
119
}
92
120
93
- test(" python udf and common condition" ) {
121
+ test(" unevaluable python udf and common condition" ) {
94
122
val query = testRelationLeft.join(
95
123
testRelationRight,
96
124
joinType = Inner ,
97
- condition = Some (pythonUDF && ' a .attr === ' c .attr))
125
+ condition = Some (unevaluableJoinCond && ' a .attr === ' c .attr))
98
126
val expected = testRelationLeft.join(
99
127
testRelationRight,
100
128
joinType = Inner ,
101
- condition = Some (' a .attr === ' c .attr)).where(pythonUDF ).analyze
129
+ condition = Some (' a .attr === ' c .attr)).where(unevaluableJoinCond ).analyze
102
130
val optimized = Optimize .execute(query.analyze)
103
131
comparePlans(optimized, expected)
104
132
}
105
133
106
- test(" python udf or common condition" ) {
134
+ test(" unevaluable python udf or common condition" ) {
107
135
val query = testRelationLeft.join(
108
136
testRelationRight,
109
137
joinType = Inner ,
110
- condition = Some (pythonUDF || ' a .attr === ' c .attr))
138
+ condition = Some (unevaluableJoinCond || ' a .attr === ' c .attr))
111
139
val expected = testRelationLeft.join(
112
140
testRelationRight,
113
141
joinType = Inner ,
114
- condition = None ).where(pythonUDF || ' a .attr === ' c .attr).analyze
142
+ condition = None ).where(unevaluableJoinCond || ' a .attr === ' c .attr).analyze
115
143
comparePlanWithCrossJoinEnable(query, expected)
116
144
}
117
145
118
- test(" pull out whole complex condition with multiple python udf" ) {
146
+ test(" pull out whole complex condition with multiple unevaluable python udf" ) {
119
147
val pythonUDF1 = PythonUDF (" pythonUDF1" , null ,
120
148
BooleanType ,
121
- Seq .empty ,
149
+ Seq (attrA, attrC) ,
122
150
PythonEvalType .SQL_BATCHED_UDF ,
123
151
udfDeterministic = true )
124
- val condition = (pythonUDF || ' a .attr === ' c .attr) && pythonUDF1
152
+ val condition = (unevaluableJoinCond || ' a .attr === ' c .attr) && pythonUDF1
125
153
126
154
val query = testRelationLeft.join(
127
155
testRelationRight,
@@ -134,13 +162,13 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
134
162
comparePlanWithCrossJoinEnable(query, expected)
135
163
}
136
164
137
- test(" partial pull out complex condition with multiple python udf" ) {
165
+ test(" partial pull out complex condition with multiple unevaluable python udf" ) {
138
166
val pythonUDF1 = PythonUDF (" pythonUDF1" , null ,
139
167
BooleanType ,
140
- Seq .empty ,
168
+ Seq (attrA, attrC) ,
141
169
PythonEvalType .SQL_BATCHED_UDF ,
142
170
udfDeterministic = true )
143
- val condition = (pythonUDF || pythonUDF1) && ' a .attr === ' c .attr
171
+ val condition = (unevaluableJoinCond || pythonUDF1) && ' a .attr === ' c .attr
144
172
145
173
val query = testRelationLeft.join(
146
174
testRelationRight,
@@ -149,23 +177,41 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
149
177
val expected = testRelationLeft.join(
150
178
testRelationRight,
151
179
joinType = Inner ,
152
- condition = Some (' a .attr === ' c .attr)).where(pythonUDF || pythonUDF1).analyze
180
+ condition = Some (' a .attr === ' c .attr)).where(unevaluableJoinCond || pythonUDF1).analyze
181
+ val optimized = Optimize .execute(query.analyze)
182
+ comparePlans(optimized, expected)
183
+ }
184
+
185
+ test(" pull out unevaluable python udf when it's mixed with evaluable one" ) {
186
+ val query = testRelationLeft.join(
187
+ testRelationRight,
188
+ joinType = Inner ,
189
+ condition = Some (evaluableJoinCond && unevaluableJoinCond))
190
+ val expected = testRelationLeft.join(
191
+ testRelationRight,
192
+ joinType = Inner ,
193
+ condition = Some (evaluableJoinCond)).where(unevaluableJoinCond).analyze
153
194
val optimized = Optimize .execute(query.analyze)
154
195
comparePlans(optimized, expected)
155
196
}
156
197
157
198
test(" throw an exception for not support join type" ) {
158
199
for (joinType <- unsupportedJoinTypes) {
159
- val thrownException = the [AnalysisException ] thrownBy {
200
+ val e = intercept [AnalysisException ] {
160
201
val query = testRelationLeft.join(
161
202
testRelationRight,
162
203
joinType,
163
- condition = Some (pythonUDF ))
204
+ condition = Some (unevaluableJoinCond ))
164
205
Optimize .execute(query.analyze)
165
206
}
166
- assert(thrownException .message.contentEquals(
207
+ assert(e .message.contentEquals(
167
208
s " Using PythonUDF in join condition of join type $joinType is not supported. " ))
209
+
210
+ val query2 = testRelationLeft.join(
211
+ testRelationRight,
212
+ joinType,
213
+ condition = Some (evaluableJoinCond))
214
+ comparePlans(Optimize .execute(query2), query2)
168
215
}
169
216
}
170
217
}
171
-
0 commit comments