
Commit 26ae9e9

[SPARK-36559][SQL][PYTHON] Create plans dedicated to distributed-sequence index for optimization
### What changes were proposed in this pull request?

This PR proposes to move the distributed-sequence index implementation into a dedicated SQL plan so that it can leverage optimizations such as column pruning.

```python
import pyspark.pandas as ps
ps.set_option('compute.default_index_type', 'distributed-sequence')
ps.range(10).id.value_counts().to_frame().spark.explain()
```

**Before:**

```bash
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Sort [count#51L DESC NULLS LAST], true, 0
   +- Exchange rangepartitioning(count#51L DESC NULLS LAST, 200), ENSURE_REQUIREMENTS, [id=#70]
      +- HashAggregate(keys=[id#37L], functions=[count(1)], output=[__index_level_0__#48L, count#51L])
         +- Exchange hashpartitioning(id#37L, 200), ENSURE_REQUIREMENTS, [id=#67]
            +- HashAggregate(keys=[id#37L], functions=[partial_count(1)], output=[id#37L, count#63L])
               +- Project [id#37L]
                  +- Filter atleastnnonnulls(1, id#37L)
                     +- Scan ExistingRDD[__index_level_0__#36L,id#37L]

# ^^^ Base DataFrame created by the output RDD from zipWithIndex (and checkpointed)
```

**After:**

```bash
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Sort [count#275L DESC NULLS LAST], true, 0
   +- Exchange rangepartitioning(count#275L DESC NULLS LAST, 200), ENSURE_REQUIREMENTS, [id=#174]
      +- HashAggregate(keys=[id#258L], functions=[count(1)])
         +- HashAggregate(keys=[id#258L], functions=[partial_count(1)])
            +- Filter atleastnnonnulls(1, id#258L)
               +- Range (0, 10, step=1, splits=16)

# ^^^ Removed the Spark job execution for `zipWithIndex`
```

### Why are the changes needed?

To leverage the optimizations of the SQL engine and avoid the unnecessary shuffle needed to create the default index.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests were added. This PR also runs all pandas-API-on-Spark unit tests after switching the default index implementation to `distributed-sequence`.

Closes #33807 from HyukjinKwon/SPARK-36559.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit 93cec49)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 5463caa commit 26ae9e9

8 files changed: +121, -39 lines


python/pyspark/pandas/tests/test_dataframe.py

Lines changed: 19 additions & 20 deletions
```diff
@@ -5160,26 +5160,25 @@ def test_print_schema(self):
             sys.stdout = prev

     def test_explain_hint(self):
-        with ps.option_context("compute.default_index_type", "sequence"):
-            psdf1 = ps.DataFrame(
-                {"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]},
-                columns=["lkey", "value"],
-            )
-            psdf2 = ps.DataFrame(
-                {"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]},
-                columns=["rkey", "value"],
-            )
-            merged = psdf1.merge(psdf2.spark.hint("broadcast"), left_on="lkey", right_on="rkey")
-            prev = sys.stdout
-            try:
-                out = StringIO()
-                sys.stdout = out
-                merged.spark.explain()
-                actual = out.getvalue().strip()
-
-                self.assertTrue("Broadcast" in actual, actual)
-            finally:
-                sys.stdout = prev
+        psdf1 = ps.DataFrame(
+            {"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]},
+            columns=["lkey", "value"],
+        )
+        psdf2 = ps.DataFrame(
+            {"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]},
+            columns=["rkey", "value"],
+        )
+        merged = psdf1.merge(psdf2.spark.hint("broadcast"), left_on="lkey", right_on="rkey")
+        prev = sys.stdout
+        try:
+            out = StringIO()
+            sys.stdout = out
+            merged.spark.explain()
+            actual = out.getvalue().strip()
+
+            self.assertTrue("Broadcast" in actual, actual)
+        finally:
+            sys.stdout = prev

     def test_mad(self):
         pdf = pd.DataFrame(
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala

Lines changed: 4 additions & 0 deletions
```diff
@@ -225,6 +225,10 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))

+      case oldVersion @ AttachDistributedSequence(sequenceAttr, _)
+          if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
+        Seq((oldVersion, oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())))
+
       case oldVersion: Generate
           if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
         val newOutput = oldVersion.generatorOutput.map(_.newInstance())
```
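This new case matters when the same distributed-sequence plan appears on both sides of a query (for example a self-join), since both sides would otherwise expose the produced sequence attribute under the same expression ID. Below is a minimal sketch of that situation using the Catalyst test DSL; the object and value names (`SelfJoinSequenceSketch`, `rel`, `withSeq`, `selfJoin`) and the `idx` attribute are illustrative, not part of this PR.

```scala
// Minimal sketch (Catalyst DSL from spark-catalyst plus the AttachDistributedSequence
// node added in this PR). Joining a plan with itself duplicates the attribute the node
// produces; the new DeduplicateRelations case re-instances `sequenceAttr` on one side.
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{AttachDistributedSequence, LocalRelation}

object SelfJoinSequenceSketch {
  def main(args: Array[String]): Unit = {
    val rel = LocalRelation('a.int)
    val withSeq = AttachDistributedSequence('idx.long, rel)
    // Before deduplication, both join sides carry 'idx with the same exprId.
    val selfJoin = withSeq.join(withSeq)
    println(selfJoin.treeString)
  }
}
```

DeduplicateRelations then rewrites one side so the two `idx` attributes get distinct expression IDs, just as it already does for `Generate` and other attribute-producing nodes.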

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 5 additions & 0 deletions
```diff
@@ -794,6 +794,11 @@ object ColumnPruning extends Rule[LogicalPlan] {
       }
       a.copy(child = Expand(newProjects, newOutput, grandChild))

+    // Prune and drop AttachDistributedSequence if the produced attribute is not referred.
+    case p @ Project(_, a @ AttachDistributedSequence(_, grandChild))
+        if !p.references.contains(a.sequenceAttr) =>
+      p.copy(child = prunedChild(grandChild, p.references))
+
     // Prunes the unused columns from child of `DeserializeToObject`
     case d @ DeserializeToObject(_, _, child) if !child.outputSet.subsetOf(d.references) =>
      d.copy(child = prunedChild(child, d.references))
```
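The ColumnPruningSuite change below exercises this rule through the full `Optimize` pipeline; to see just the new case, the rule can also be applied directly to a hand-built plan. A minimal sketch follows, assuming the Catalyst DSL and the node added in this PR; `PruneSequenceSketch`, `input`, `plan`, and the `idx` attribute are illustrative names.

```scala
// Minimal sketch: a Project that never references the sequence column lets the new
// ColumnPruning case remove the AttachDistributedSequence node from the plan.
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.ColumnPruning
import org.apache.spark.sql.catalyst.plans.logical.{AttachDistributedSequence, LocalRelation}

object PruneSequenceSketch {
  def main(args: Array[String]): Unit = {
    val input = LocalRelation('a.int, 'b.int)
    val plan = AttachDistributedSequence('idx.long, input).select('a).analyze
    // The printed plan no longer contains AttachDistributedSequence; the remaining
    // nested Projects are merged by later rules such as CollapseProject.
    println(ColumnPruning(plan).treeString)
  }
}
```

This is the optimization that turns the `value_counts()` plan in the commit message from a scan over a `zipWithIndex` RDD into a plain `Range` scan: the sequence column is never referenced, so the node, and the job it would trigger, disappears.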

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala

Lines changed: 17 additions & 0 deletions
```diff
@@ -115,3 +115,20 @@ case class ArrowEvalPython(
   override protected def withNewChildInternal(newChild: LogicalPlan): ArrowEvalPython =
     copy(child = newChild)
 }
+
+/**
+ * A logical plan that adds a new long column with the name `name` that
+ * increases one by one. This is for 'distributed-sequence' default index
+ * in pandas API on Spark.
+ */
+case class AttachDistributedSequence(
+    sequenceAttr: Attribute,
+    child: LogicalPlan) extends UnaryNode {
+
+  override val producedAttributes: AttributeSet = AttributeSet(sequenceAttr)
+
+  override val output: Seq[Attribute] = sequenceAttr +: child.output
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): AttachDistributedSequence =
+    copy(child = newChild)
+}
```
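The node's contract is intentionally small: it prepends a single produced attribute to whatever the child outputs, which is what lets both DeduplicateRelations and ColumnPruning reason about it. A minimal sketch using the Catalyst DSL; `AttachSequenceOutputSketch`, `child`, `node`, and `idx` are illustrative names.

```scala
// Minimal sketch: AttachDistributedSequence prepends the sequence attribute to the
// child's columns and reports it as a produced (not required) attribute.
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{AttachDistributedSequence, LocalRelation}

object AttachSequenceOutputSketch {
  def main(args: Array[String]): Unit = {
    val child = LocalRelation('a.int, 'b.string)
    val node = AttachDistributedSequence('idx.long, child)
    println(node.output.map(_.name))                    // List(idx, a, b)
    println(node.producedAttributes.toSeq.map(_.name))  // List(idx): prunable if unused
  }
}
```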

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala

Lines changed: 7 additions & 1 deletion
```diff
@@ -452,5 +452,11 @@ class ColumnPruningSuite extends PlanTest {
     val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze
     comparePlans(optimized, expected)
   }
-  // todo: add more tests for column pruning
+
+  test("SPARK-36559 Prune and drop distributed-sequence if the produced column is not referred") {
+    val input = LocalRelation('a.int, 'b.int, 'c.int)
+    val plan1 = AttachDistributedSequence('d.int, input).select('a)
+    val correctAnswer1 = Project(Seq('a), input).analyze
+    comparePlans(Optimize.execute(plan1.analyze), correctAnswer1)
+  }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 5 additions & 18 deletions
```diff
@@ -3514,24 +3514,11 @@ class Dataset[T] private[sql](
    * This is for 'distributed-sequence' default index in pandas API on Spark.
    */
   private[sql] def withSequenceColumn(name: String) = {
-    val rdd: RDD[InternalRow] =
-      // Checkpoint the DataFrame to fix the partition ID.
-      localCheckpoint(false)
-        .queryExecution.toRdd.zipWithIndex().mapPartitions { iter =>
-          val joinedRow = new JoinedRow
-          val unsafeRowWriter =
-            new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1)
-
-          iter.map { case (row, id) =>
-            // Writes to an UnsafeRow directly
-            unsafeRowWriter.reset()
-            unsafeRowWriter.write(0, id)
-            joinedRow(unsafeRowWriter.getRow, row)
-          }
-        }
-
-    sparkSession.internalCreateDataFrame(
-      rdd, StructType(StructField(name, LongType, nullable = false) +: schema), isStreaming)
+    Dataset.ofRows(
+      sparkSession,
+      AttachDistributedSequence(
+        AttributeReference(name, LongType, nullable = false)(),
+        logicalPlan))
   }

   /**
```
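With this change `withSequenceColumn` no longer builds an RDD; it only wraps the logical plan, so the analyzer and optimizer can still see (and prune) the index column. A minimal sketch of how it could be exercised; the method is `private[sql]`, so the snippet assumes it is compiled inside the `org.apache.spark.sql` package, and `WithSequenceColumnSketch` is an illustrative name.

```scala
// Minimal sketch (must live in the org.apache.spark.sql package because
// withSequenceColumn is private[sql]): the plan now contains an
// AttachDistributedSequence node instead of a scan over a zipWithIndex RDD.
package org.apache.spark.sql

object WithSequenceColumnSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    val withSeq = spark.range(10).toDF("id").withSequenceColumn("__index_level_0__")
    withSeq.explain(true) // look for AttachDistributedSequence in the analyzed/optimized plans
    spark.stop()
  }
}
```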

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -712,6 +712,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           func, output, planLater(left), planLater(right)) :: Nil
       case logical.MapInPandas(func, output, child) =>
         execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
+      case logical.AttachDistributedSequence(attr, child) =>
+        execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
       case logical.AppendColumns(f, _, _, in, out, child) =>
```
sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala

Lines changed: 62 additions & 0 deletions (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.python

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}

/**
 * A physical plan that adds a new long column with `sequenceAttr` that
 * increases one by one. This is for 'distributed-sequence' default index
 * in pandas API on Spark.
 */
case class AttachDistributedSequenceExec(
    sequenceAttr: Attribute,
    child: SparkPlan)
  extends UnaryExecNode {

  override def producedAttributes: AttributeSet = AttributeSet(sequenceAttr)

  override val output: Seq[Attribute] = sequenceAttr +: child.output

  override def outputPartitioning: Partitioning = child.outputPartitioning

  override protected def doExecute(): RDD[InternalRow] = {
    child.execute().map(_.copy())
      .localCheckpoint() // to avoid execute multiple jobs. zipWithIndex launches a Spark job.
      .zipWithIndex().mapPartitions { iter =>
        val unsafeProj = UnsafeProjection.create(output, output)
        val joinedRow = new JoinedRow
        val unsafeRowWriter =
          new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1)

        iter.map { case (row, id) =>
          // Writes to an UnsafeRow directly
          unsafeRowWriter.reset()
          unsafeRowWriter.write(0, id)
          joinedRow(unsafeRowWriter.getRow, row)
        }.map(unsafeProj)
      }
  }

  override protected def withNewChildInternal(newChild: SparkPlan): AttachDistributedSequenceExec =
    copy(child = newChild)
}
```
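The physical node still relies on `RDD.zipWithIndex`, which runs one job to count the records in each partition and then assigns contiguous ids; the `localCheckpoint()` before it keeps the child from being recomputed by that extra job. Below is a plain RDD-level sketch of the same indexing scheme, with no Catalyst involved; `ZipWithIndexSketch` and the sample data are illustrative.

```scala
// Minimal RDD-level sketch of the indexing done in doExecute: localCheckpoint pins
// the partitions, then zipWithIndex assigns a global 0-based id to every record.
import org.apache.spark.sql.SparkSession

object ZipWithIndexSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    val rdd = spark.sparkContext.parallelize(Seq("a", "b", "c", "d"), numSlices = 2)
    val indexed = rdd.localCheckpoint().zipWithIndex() // (record, global id)
    indexed.collect().sortBy(_._2).foreach { case (v, id) => println(s"$id -> $v") }
    spark.stop()
  }
}
```

Compared with the removed Dataset-side implementation, this work now happens inside the planned query, so column pruning can eliminate it entirely when the index column is unused.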
