
Commit 93cec49

[SPARK-36559][SQL][PYTHON] Create plans dedicated to distributed-sequence index for optimization
### What changes were proposed in this pull request?

This PR proposes to move the distributed-sequence index implementation into a dedicated SQL plan so that it can leverage optimizations such as column pruning.

```python
import pyspark.pandas as ps
ps.set_option('compute.default_index_type', 'distributed-sequence')
ps.range(10).id.value_counts().to_frame().spark.explain()
```

**Before:**

```bash
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Sort [count#51L DESC NULLS LAST], true, 0
   +- Exchange rangepartitioning(count#51L DESC NULLS LAST, 200), ENSURE_REQUIREMENTS, [id=#70]
      +- HashAggregate(keys=[id#37L], functions=[count(1)], output=[__index_level_0__#48L, count#51L])
         +- Exchange hashpartitioning(id#37L, 200), ENSURE_REQUIREMENTS, [id=#67]
            +- HashAggregate(keys=[id#37L], functions=[partial_count(1)], output=[id#37L, count#63L])
               +- Project [id#37L]
                  +- Filter atleastnnonnulls(1, id#37L)
                     +- Scan ExistingRDD[__index_level_0__#36L,id#37L]
                        # ^^^ Base DataFrame created by the output RDD from zipWithIndex (and checkpointed)
```

**After:**

```bash
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Sort [count#275L DESC NULLS LAST], true, 0
   +- Exchange rangepartitioning(count#275L DESC NULLS LAST, 200), ENSURE_REQUIREMENTS, [id=#174]
      +- HashAggregate(keys=[id#258L], functions=[count(1)])
         +- HashAggregate(keys=[id#258L], functions=[partial_count(1)])
            +- Filter atleastnnonnulls(1, id#258L)
               +- Range (0, 10, step=1, splits=16)
                  # ^^^ Removed the Spark job execution for `zipWithIndex`
```

### Why are the changes needed?

To leverage the optimizations of the SQL engine and avoid the unnecessary shuffle needed to create the default index.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests were added. Also, this PR runs all unit tests of the pandas API on Spark after switching the default index implementation to `distributed-sequence`.

Closes #33807 from HyukjinKwon/SPARK-36559.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 3e32ea1 commit 93cec49


8 files changed: +121 −39 lines changed


python/pyspark/pandas/tests/test_dataframe.py

Lines changed: 19 additions & 20 deletions
```diff
@@ -5160,26 +5160,25 @@ def test_print_schema(self):
             sys.stdout = prev
 
     def test_explain_hint(self):
-        with ps.option_context("compute.default_index_type", "sequence"):
-            psdf1 = ps.DataFrame(
-                {"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]},
-                columns=["lkey", "value"],
-            )
-            psdf2 = ps.DataFrame(
-                {"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]},
-                columns=["rkey", "value"],
-            )
-            merged = psdf1.merge(psdf2.spark.hint("broadcast"), left_on="lkey", right_on="rkey")
-            prev = sys.stdout
-            try:
-                out = StringIO()
-                sys.stdout = out
-                merged.spark.explain()
-                actual = out.getvalue().strip()
-
-                self.assertTrue("Broadcast" in actual, actual)
-            finally:
-                sys.stdout = prev
+        psdf1 = ps.DataFrame(
+            {"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]},
+            columns=["lkey", "value"],
+        )
+        psdf2 = ps.DataFrame(
+            {"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]},
+            columns=["rkey", "value"],
+        )
+        merged = psdf1.merge(psdf2.spark.hint("broadcast"), left_on="lkey", right_on="rkey")
+        prev = sys.stdout
+        try:
+            out = StringIO()
+            sys.stdout = out
+            merged.spark.explain()
+            actual = out.getvalue().strip()
+
+            self.assertTrue("Broadcast" in actual, actual)
+        finally:
+            sys.stdout = prev
 
     def test_mad(self):
         pdf = pd.DataFrame(
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala

Lines changed: 4 additions & 0 deletions
```diff
@@ -225,6 +225,10 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
 
+      case oldVersion @ AttachDistributedSequence(sequenceAttr, _)
+        if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
+        Seq((oldVersion, oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())))
+
       case oldVersion: Generate
           if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
         val newOutput = oldVersion.generatorOutput.map(_.newInstance())
```
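
The new case keeps attribute IDs unique when the same distributed-sequence relation appears on both sides of a join, for example in a self-join of a pandas-on-Spark frame. A minimal sketch of the conflicting-plan shape this handles, assuming the Catalyst test DSL (the `'x.type` attribute helpers and `LocalRelation` used elsewhere in this patch):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{AttachDistributedSequence, LocalRelation}

// Both sides of the join reuse the same plan, so the produced `seq` attribute carries
// the same expression ID on both sides and conflicts. DeduplicateRelations resolves
// this by copying one side with sequenceAttr.newInstance().
val base = AttachDistributedSequence('seq.long, LocalRelation('a.int))
val selfJoin = base.join(base)
```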

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 5 additions & 0 deletions
```diff
@@ -800,6 +800,11 @@ object ColumnPruning extends Rule[LogicalPlan] {
       }
       a.copy(child = Expand(newProjects, newOutput, grandChild))
 
+    // Prune and drop AttachDistributedSequence if the produced attribute is not referred.
+    case p @ Project(_, a @ AttachDistributedSequence(_, grandChild))
+        if !p.references.contains(a.sequenceAttr) =>
+      p.copy(child = prunedChild(grandChild, p.references))
+
     // Prunes the unused columns from child of `DeserializeToObject`
     case d @ DeserializeToObject(_, _, child) if !child.outputSet.subsetOf(d.references) =>
       d.copy(child = prunedChild(child, d.references))
```
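
The new `ColumnPruning` case drops `AttachDistributedSequence` entirely when the projection above it never references the sequence column, and then keeps pruning into the child. A rough sketch of the plan shapes involved, assuming the Catalyst test DSL with hypothetical attribute names (the new ColumnPruningSuite test further below exercises the same case):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{AttachDistributedSequence, LocalRelation}

// Before: Project [a] <- AttachDistributedSequence [seq] <- LocalRelation [a, b, c]
val input = LocalRelation('a.int, 'b.int, 'c.int)
val before = AttachDistributedSequence('seq.long, input).select('a)
// After ColumnPruning: Project [a] <- LocalRelation [a, b, c]
// (the sequence node disappears, and prunedChild can narrow the child's columns further)
```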

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala

Lines changed: 17 additions & 0 deletions
```diff
@@ -115,3 +115,20 @@ case class ArrowEvalPython(
   override protected def withNewChildInternal(newChild: LogicalPlan): ArrowEvalPython =
     copy(child = newChild)
 }
+
+/**
+ * A logical plan that adds a new long column with the name `name` that
+ * increases one by one. This is for 'distributed-sequence' default index
+ * in pandas API on Spark.
+ */
+case class AttachDistributedSequence(
+    sequenceAttr: Attribute,
+    child: LogicalPlan) extends UnaryNode {
+
+  override val producedAttributes: AttributeSet = AttributeSet(sequenceAttr)
+
+  override val output: Seq[Attribute] = sequenceAttr +: child.output
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): AttachDistributedSequence =
+    copy(child = newChild)
+}
```
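
A quick illustration of the node's contract, again assuming the Catalyst test DSL and hypothetical attribute names: the sequence attribute is prepended to the child's output, and it is the only attribute the node itself produces, which is what lets `ColumnPruning` remove it when unused.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{AttachDistributedSequence, LocalRelation}

val child = LocalRelation('a.int, 'b.string)
val attach = AttachDistributedSequence('seq.long, child)
// attach.output             == Seq(seq, a, b)      -- sequence column comes first
// attach.producedAttributes == AttributeSet(seq)   -- only the sequence column is produced here
```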

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala

Lines changed: 7 additions & 1 deletion
```diff
@@ -452,5 +452,11 @@ class ColumnPruningSuite extends PlanTest {
     val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze
     comparePlans(optimized, expected)
   }
-  // todo: add more tests for column pruning
+
+  test("SPARK-36559 Prune and drop distributed-sequence if the produced column is not referred") {
+    val input = LocalRelation('a.int, 'b.int, 'c.int)
+    val plan1 = AttachDistributedSequence('d.int, input).select('a)
+    val correctAnswer1 = Project(Seq('a), input).analyze
+    comparePlans(Optimize.execute(plan1.analyze), correctAnswer1)
+  }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 5 additions & 18 deletions
```diff
@@ -3541,24 +3541,11 @@ class Dataset[T] private[sql](
    * This is for 'distributed-sequence' default index in pandas API on Spark.
    */
   private[sql] def withSequenceColumn(name: String) = {
-    val rdd: RDD[InternalRow] =
-      // Checkpoint the DataFrame to fix the partition ID.
-      localCheckpoint(false)
-        .queryExecution.toRdd.zipWithIndex().mapPartitions { iter =>
-          val joinedRow = new JoinedRow
-          val unsafeRowWriter =
-            new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1)
-
-          iter.map { case (row, id) =>
-            // Writes to an UnsafeRow directly
-            unsafeRowWriter.reset()
-            unsafeRowWriter.write(0, id)
-            joinedRow(unsafeRowWriter.getRow, row)
-          }
-        }
-
-    sparkSession.internalCreateDataFrame(
-      rdd, StructType(StructField(name, LongType, nullable = false) +: schema), isStreaming)
+    Dataset.ofRows(
+      sparkSession,
+      AttachDistributedSequence(
+        AttributeReference(name, LongType, nullable = false)(),
+        logicalPlan))
   }
 
   /**
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -754,6 +754,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           func, output, planLater(left), planLater(right)) :: Nil
       case logical.MapInPandas(func, output, child) =>
         execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
+      case logical.AttachDistributedSequence(attr, child) =>
+        execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
       case logical.AppendColumns(f, _, _, in, out, child) =>
```
sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala

Lines changed: 62 additions & 0 deletions

```diff
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+
+/**
+ * A physical plan that adds a new long column with `sequenceAttr` that
+ * increases one by one. This is for 'distributed-sequence' default index
+ * in pandas API on Spark.
+ */
+case class AttachDistributedSequenceExec(
+    sequenceAttr: Attribute,
+    child: SparkPlan)
+  extends UnaryExecNode {
+
+  override def producedAttributes: AttributeSet = AttributeSet(sequenceAttr)
+
+  override val output: Seq[Attribute] = sequenceAttr +: child.output
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().map(_.copy())
+      .localCheckpoint() // to avoid execute multiple jobs. zipWithIndex launches a Spark job.
+      .zipWithIndex().mapPartitions { iter =>
+        val unsafeProj = UnsafeProjection.create(output, output)
+        val joinedRow = new JoinedRow
+        val unsafeRowWriter =
+          new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1)
+
+        iter.map { case (row, id) =>
+          // Writes to an UnsafeRow directly
+          unsafeRowWriter.reset()
+          unsafeRowWriter.write(0, id)
+          joinedRow(unsafeRowWriter.getRow, row)
+        }.map(unsafeProj)
+      }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): AttachDistributedSequenceExec =
+    copy(child = newChild)
+}
```
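
The `localCheckpoint()` call exists because `zipWithIndex` launches an extra Spark job to count the rows of each partition before it can assign global indices; checkpointing the child's output first keeps that extra job from recomputing the whole child plan. A minimal RDD-level sketch of the same pattern, assuming an existing `SparkContext` named `sc`:

```scala
import org.apache.spark.SparkContext

def attachIndex(sc: SparkContext): Unit = {
  val rdd = sc.parallelize(1 to 10, numSlices = 4)
  val indexed = rdd
    .localCheckpoint()  // persist partitions once computed, so the result job does not recompute them
    .zipWithIndex()     // (value, globalIndex) pairs with a contiguous 0-based index
  indexed.collect().foreach(println)
}
```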
