Skip to content

Commit 9c30116

Browse files
ulysses-you authored and
dongjoon-hyun committed
[SPARK-33857][SQL] Unify the default seed of random functions
### What changes were proposed in this pull request? Unify the default seed of random functions: 1. Add a placeholder expression `UnresolvedSeed` as the default seed. 2. Change the default seed of `Rand`, `Randn`, `Uuid`, and `Shuffle` to `UnresolvedSeed`. 3. Replace `UnresolvedSeed` with a real seed in the `ResolveRandomSeed` rule. ### Why are the changes needed? `Uuid` and `Shuffle` use the `ResolveRandomSeed` rule to set the seed if the user doesn't give a seed value. `Rand` and `Randn` do this at construction time. It's better to unify the default seed on the Analyzer side, since `ExpressionWithRandomSeed` is already used for streaming queries. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Passes the existing tests, and a new test is added. Closes #30864 from ulysses-you/SPARK-33857. Authored-by: ulysses-you <ulyssesyou18@gmail.com> Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
1 parent 700f5ab commit 9c30116

File tree

6 files changed

+42
-14
lines changed

6 files changed

+42
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3000,8 +3000,8 @@ class Analyzer(override val catalogManager: CatalogManager)
30003000
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
30013001
case p if p.resolved => p
30023002
case p => p transformExpressionsUp {
3003-
case Uuid(None) => Uuid(Some(random.nextLong()))
3004-
case Shuffle(child, None) => Shuffle(child, Some(random.nextLong()))
3003+
case e: ExpressionWithRandomSeed if e.seedExpression == UnresolvedSeed =>
3004+
e.withNewSeed(random.nextLong())
30053005
}
30063006
}
30073007
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,3 +561,12 @@ case class UnresolvedHaving(
561561
override lazy val resolved: Boolean = false
562562
override def output: Seq[Attribute] = child.output
563563
}
564+
565+
/**
566+
* A placeholder expression used as the default seed of random functions; it is replaced with a concrete seed during analysis.
567+
*/
568+
case object UnresolvedSeed extends LeafExpression with Unevaluable {
569+
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
570+
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
571+
override lazy val resolved = false
572+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import scala.collection.mutable
2323
import scala.reflect.ClassTag
2424

2525
import org.apache.spark.sql.catalyst.InternalRow
26-
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
26+
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedSeed}
2727
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
2828
import org.apache.spark.sql.catalyst.expressions.codegen._
2929
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -943,6 +943,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
943943

944944
def this(child: Expression) = this(child, None)
945945

946+
override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed)
947+
946948
override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed))
947949

948950
override lazy val resolved: Boolean =

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT}
2121
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
2223
import org.apache.spark.sql.catalyst.expressions.codegen._
2324
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2425
import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
@@ -187,6 +188,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta
187188

188189
def this() = this(None)
189190

191+
override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed)
192+
190193
override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed))
191194

192195
override lazy val resolved: Boolean = randomSeed.isDefined

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.AnalysisException
2121
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
2223
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
2324
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2425
import org.apache.spark.sql.types._
25-
import org.apache.spark.util.Utils
2626
import org.apache.spark.util.random.XORShiftRandom
2727

2828
/**
@@ -32,7 +32,8 @@ import org.apache.spark.util.random.XORShiftRandom
3232
*
3333
* Since this expression is stateful, it cannot be a case object.
3434
*/
35-
abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful {
35+
abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful
36+
with ExpressionWithRandomSeed {
3637
/**
3738
* Record ID within each partition. By being transient, the Random Number Generator is
3839
* reset every time we serialize and deserialize and initialize it.
@@ -43,7 +44,9 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful
4344
rng = new XORShiftRandom(seed + partitionIndex)
4445
}
4546

46-
@transient protected lazy val seed: Long = child match {
47+
override def seedExpression: Expression = child
48+
49+
@transient protected lazy val seed: Long = seedExpression match {
4750
case Literal(s, IntegerType) => s.asInstanceOf[Int]
4851
case Literal(s, LongType) => s.asInstanceOf[Long]
4952
case _ => throw new AnalysisException(
@@ -62,6 +65,7 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful
6265
* Usually the random seed needs to be renewed at each execution under streaming queries.
6366
*/
6467
trait ExpressionWithRandomSeed {
68+
def seedExpression: Expression
6569
def withNewSeed(seed: Long): Expression
6670
}
6771

@@ -84,14 +88,13 @@ trait ExpressionWithRandomSeed {
8488
since = "1.5.0",
8589
group = "math_funcs")
8690
// scalastyle:on line.size.limit
87-
case class Rand(child: Expression, hideSeed: Boolean = false)
88-
extends RDG with ExpressionWithRandomSeed {
91+
case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG {
8992

90-
def this() = this(Literal(Utils.random.nextLong(), LongType), true)
93+
def this() = this(UnresolvedSeed, true)
9194

9295
def this(child: Expression) = this(child, false)
9396

94-
override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType))
97+
override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType), hideSeed)
9598

9699
override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
97100

@@ -136,14 +139,13 @@ object Rand {
136139
since = "1.5.0",
137140
group = "math_funcs")
138141
// scalastyle:on line.size.limit
139-
case class Randn(child: Expression, hideSeed: Boolean = false)
140-
extends RDG with ExpressionWithRandomSeed {
142+
case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG {
141143

142-
def this() = this(Literal(Utils.random.nextLong(), LongType), true)
144+
def this() = this(UnresolvedSeed, true)
143145

144146
def this(child: Expression) = this(child, false)
145147

146-
override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType))
148+
override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType), hideSeed)
147149

148150
override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
149151

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,4 +1006,16 @@ class AnalysisSuite extends AnalysisTest with Matchers {
10061006
checkAnalysis(plan, expect)
10071007
}
10081008
}
1009+
1010+
test("SPARK-33857: Unify the default seed of random functions") {
1011+
Seq(new Rand(), new Randn(), Shuffle(Literal(Array(1))), Uuid()).foreach { r =>
1012+
assert(r.seedExpression == UnresolvedSeed)
1013+
val p = getAnalyzer.execute(Project(Seq(r.as("r")), testRelation))
1014+
assert(
1015+
p.asInstanceOf[Project].projectList.head.asInstanceOf[Alias]
1016+
.child.asInstanceOf[ExpressionWithRandomSeed]
1017+
.seedExpression.isInstanceOf[Literal]
1018+
)
1019+
}
1020+
}
10091021
}

0 commit comments

Comments
 (0)