Skip to content

[SPARK-47527][SQL] Normalize common expression ids during canonicalization of With expression #45677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ case class With(child: Expression, defs: Seq[CommonExpressionDef])
newChildren: IndexedSeq[Expression]): Expression = {
copy(child = newChildren.head, defs = newChildren.tail.map(_.asInstanceOf[CommonExpressionDef]))
}

/**
 * Canonical form of this expression in which every common expression id is replaced by the
 * position (0, 1, 2, ...) of its definition inside `defs`. A reference whose id does not match
 * any definition in `defs` keeps its id unchanged.
 */
override lazy val canonicalized: Expression = {
  // The ids have to be rewritten before the regular canonicalization runs: canonicalization
  // may reorder commutative expressions whose operands contain common expressions, and the id
  // takes part in that ordering.
  val positionById = defs.zipWithIndex.map { case (d, idx) => d.id -> idx.toLong }.toMap
  val canonicalChild = child
    .transform { case ref: CommonExpressionRef =>
      ref.copy(id = positionById.getOrElse(ref.id, ref.id))
    }
    .canonicalized
  val canonicalDefs = defs.map { d =>
    val renumbered = d.copy(id = positionById.getOrElse(d.id, d.id))
    renumbered.canonicalized.asInstanceOf[CommonExpressionDef]
  }
  With(canonicalChild, canonicalDefs)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

/**
 * Tests for the `With` expression.
 */
class WithSuite extends PlanTest {
  // Two-column relation that supplies the attributes used to build the expressions under test.
  private val testRelation = LocalRelation($"a".int, $"b".int)

  test("SPARK-47527: Canonicalization") {
    val attrA = testRelation.output.head
    val attrB = testRelation.output.last

    // Builds `With(ref(a > 2) AND ref(b > 3), defs)`, where the two common expression
    // definitions carry the given ids.
    def conjunctionWith(firstId: Long, secondId: Long): With = {
      val firstDef = CommonExpressionDef(attrA > 2, firstId)
      val secondDef = CommonExpressionDef(attrB > 3, secondId)
      val firstRef = new CommonExpressionRef(firstDef)
      val secondRef = new CommonExpressionRef(secondDef)
      With(And(firstRef, secondRef), Seq(firstDef, secondDef))
    }

    // The two expressions differ only in the ids of their common expressions, so they are
    // expected to be semantically equal.
    val lowIds = conjunctionWith(7, 8)
    val highIds = conjunctionWith(12, 13)
    assert(lowIds.semanticEquals(highIds))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1770,4 +1770,33 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
withSQLConf(SQLConf.DEFAULT_CACHE_STORAGE_LEVEL.key -> "DISK") {}
}
}

// Caches two temp views and checks that selecting from each one is answered by exactly one
// in-memory relation, i.e. that the cached plan is matched.
test("SPARK-47527: Cache should be hit for queries using common expressions") {
  withTempView("data", "the_query", "the_query2") {
    // Base data: ids 0..9, each paired with id * 10.
    spark.range(10)
      .selectExpr("id", "id * 10")
      .toDF("id", "val")
      .createOrReplaceTempView("data")

    // First cached view: a BETWEEN filter.
    sql(
      """create or replace temp view the_query as
        |select *
        |from data
        |where id between 2 and 4""".stripMargin)
    sql("cache table the_query")
    val filtered = sql("SELECT * FROM the_query order by id")
    checkAnswer(filtered, (2 to 4).map(id => Row(id, id * 10)))
    assert(getNumInMemoryRelations(filtered) == 1)

    // Second cached view: a count_if aggregate.
    sql(
      """create or replace temp view the_query2 as
        |select id, count_if(val > 0) as snt
        |from data
        |group by id""".stripMargin)
    sql("cache table the_query2")
    val counted = sql("SELECT * FROM the_query2 order by id")
    // id 0 has val 0, so its count is 0; every other id has a positive val and counts 1.
    checkAnswer(counted, Row(0, 0) +: (1 to 9).map(id => Row(id, 1)))
    assert(getNumInMemoryRelations(counted) == 1)
  }
}
}