Skip to content

Commit 33a0e9a

Browse files
fix tests
1 parent c78b0f0 commit 33a0e9a

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -733,12 +733,12 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
733733
*/
734734
private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries {
735735
case LateralJoin(left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _), _, None)
736-
if right.plan.subqueriesAll.isEmpty =>
736+
if right.plan.subqueriesAll.isEmpty && right.joinCond.isEmpty =>
737737
Project(left.output ++ projectList, left)
738738
case p: LogicalPlan => p.transformExpressionsUpWithPruning(
739739
_.containsPattern(SCALAR_SUBQUERY)) {
740740
case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _)
741-
if s.plan.subqueriesAll.isEmpty =>
741+
if s.plan.subqueriesAll.isEmpty && s.joinCond.isEmpty =>
742742
assert(projectList.size == 1)
743743
projectList.head
744744
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
3232
val x = AttributeReference("x", IntegerType)()
3333
val y = AttributeReference("y", IntegerType)()
3434
val z = AttributeReference("z", IntegerType)()
35+
val t0 = OneRowRelation()
3536
val testRelation = LocalRelation(a, b, c)
3637
val testRelation2 = LocalRelation(x, y, z)
3738

@@ -203,23 +204,24 @@ class DecorrelateInnerQuerySuite extends PlanTest {
203204

204205
test("correlated values in project") {
205206
val outerPlan = testRelation2
206-
val innerPlan = Project(Seq(OuterReference(x), OuterReference(y)), OneRowRelation())
207-
val correctAnswer = Project(Seq(x, y), DomainJoin(Seq(x, y), OneRowRelation()))
207+
val innerPlan = Project(Seq(OuterReference(x).as("x1"), OuterReference(y).as("y1")), t0)
208+
val correctAnswer = Project(
209+
Seq(x.as("x1"), y.as("y1"), x, y), DomainJoin(Seq(x, y), t0))
208210
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
209211
}
210212

211213
test("correlated values in project with alias") {
212214
val outerPlan = testRelation2
213215
val innerPlan =
214-
Project(Seq(OuterReference(x), 'y1, 'sum),
216+
Project(Seq(OuterReference(x).as("x1"), 'y1, 'sum),
215217
Project(Seq(
216218
OuterReference(x),
217219
OuterReference(y).as("y1"),
218220
Add(OuterReference(x), OuterReference(y)).as("sum")),
219221
testRelation)).analyze
220222
val correctAnswer =
221-
Project(Seq(x, 'y1, 'sum, y),
222-
Project(Seq(x, y.as("y1"), (x + y).as("sum"), y),
223+
Project(Seq(x.as("x1"), 'y1, 'sum, x, y),
224+
Project(Seq(x.as(x.name), y.as("y1"), (x + y).as("sum"), x, y),
223225
DomainJoin(Seq(x, y), testRelation))).analyze
224226
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
225227
}
@@ -228,28 +230,28 @@ class DecorrelateInnerQuerySuite extends PlanTest {
228230
val outerPlan = testRelation2
229231
val innerPlan =
230232
Project(
231-
Seq(OuterReference(x)),
233+
Seq(OuterReference(x).as("x1")),
232234
Filter(
233235
And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)),
234236
testRelation
235237
)
236238
)
237-
val correctAnswer = Project(Seq(a, c), Filter(b === 1, testRelation))
239+
val correctAnswer = Project(Seq(a.as("x1"), a, c), Filter(b === 1, testRelation))
238240
check(innerPlan, outerPlan, correctAnswer, Seq(x === a, x + y === c))
239241
}
240242

241243
test("correlated values in project without correlated equality conditions in filter") {
242244
val outerPlan = testRelation2
243245
val innerPlan =
244246
Project(
245-
Seq(OuterReference(y)),
247+
Seq(OuterReference(y).as("y1")),
246248
Filter(
247249
And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)),
248250
testRelation
249251
)
250252
)
251253
val correctAnswer =
252-
Project(Seq(y, a, c),
254+
Project(Seq(y.as("y1"), y, a, c),
253255
Filter(b === 1,
254256
DomainJoin(Seq(y), testRelation)
255257
)

0 commit comments

Comments
 (0)