Skip to content

Commit 4b76c16

Browse files
maropu authored and rshkv committed
[SPARK-31594][SQL] Do not display the seed of rand/randn with no argument in output schema
### What changes were proposed in this pull request?

This PR intends to update `sql` in `Rand`/`Randn` with no argument to make a column name deterministic.

Before this PR (a column name changes run-by-run):
```
scala> sql("select rand()").show()
+-------------------------+
|rand(7986133828002692830)|
+-------------------------+
|       0.9524061403696937|
+-------------------------+
```

After this PR (a column name fixed):
```
scala> sql("select rand()").show()
+------------------+
|            rand()|
+------------------+
|0.7137935639522275|
+------------------+

// If a seed given, it is still shown in a column name
// (the same with the current behaviour)
scala> sql("select rand(1)").show()
+------------------+
|           rand(1)|
+------------------+
|0.6363787615254752|
+------------------+

// We can still check a seed in explain output:
scala> sql("select rand()").explain()
== Physical Plan ==
*(1) Project [rand(-2282124938778456838) AS rand()#0]
+- *(1) Scan OneRowRelation[]
```

Note: This fix comes from apache#28194; the ongoing PR tests the output schema of expressions, so their schemas must be deterministic for the tests.

### Why are the changes needed?

To make output schema deterministic.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Added unit tests.

Closes apache#28392 from maropu/SPARK-31594.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent 5329a42 commit 4b76c16

File tree

3 files changed

+52
-6
lines changed

3 files changed

+52
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,12 @@ trait ExpressionWithRandomSeed {
8383
""",
8484
since = "1.5.0")
8585
// scalastyle:on line.size.limit
86-
case class Rand(child: Expression) extends RDG with ExpressionWithRandomSeed {
86+
case class Rand(child: Expression, hideSeed: Boolean = false)
87+
extends RDG with ExpressionWithRandomSeed {
8788

88-
def this() = this(Literal(Utils.random.nextLong(), LongType))
89+
def this() = this(Literal(Utils.random.nextLong(), LongType), true)
90+
91+
def this(child: Expression) = this(child, false)
8992

9093
override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType))
9194

@@ -101,7 +104,12 @@ case class Rand(child: Expression) extends RDG with ExpressionWithRandomSeed {
101104
isNull = FalseLiteral)
102105
}
103106

104-
override def freshCopy(): Rand = Rand(child)
107+
override def freshCopy(): Rand = Rand(child, hideSeed)
108+
109+
override def flatArguments: Iterator[Any] = Iterator(child)
110+
override def sql: String = {
111+
s"rand(${if (hideSeed) "" else child.sql})"
112+
}
105113
}
106114

107115
object Rand {
@@ -126,9 +134,12 @@ object Rand {
126134
""",
127135
since = "1.5.0")
128136
// scalastyle:on line.size.limit
129-
case class Randn(child: Expression) extends RDG with ExpressionWithRandomSeed {
137+
case class Randn(child: Expression, hideSeed: Boolean = false)
138+
extends RDG with ExpressionWithRandomSeed {
130139

131-
def this() = this(Literal(Utils.random.nextLong(), LongType))
140+
def this() = this(Literal(Utils.random.nextLong(), LongType), true)
141+
142+
def this(child: Expression) = this(child, false)
132143

133144
override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType))
134145

@@ -144,7 +155,12 @@ case class Randn(child: Expression) extends RDG with ExpressionWithRandomSeed {
144155
isNull = FalseLiteral)
145156
}
146157

147-
override def freshCopy(): Randn = Randn(child)
158+
override def freshCopy(): Randn = Randn(child, hideSeed)
159+
160+
override def flatArguments: Iterator[Any] = Iterator(child)
161+
override def sql: String = {
162+
s"randn(${if (hideSeed) "" else child.sql})"
163+
}
148164
}
149165

150166
object Randn {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,11 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
3434
checkEvaluation(Rand(5419823303878592871L), 0.7145363364564755)
3535
checkEvaluation(Randn(5419823303878592871L), 0.7816815274533012)
3636
}
37+
38+
test("SPARK-31594: Do not display the seed of rand/randn with no argument in output schema") {
  // Each pair is (generated SQL string, expected SQL string).
  // With hideSeed = true the seed literal is omitted from the SQL text;
  // with hideSeed = false the seed literal is rendered as-is.
  val seed = Literal(1L)
  val cases = Seq(
    Rand(seed, hideSeed = true).sql -> "rand()",
    Randn(seed, hideSeed = true).sql -> "randn()",
    Rand(seed, hideSeed = false).sql -> "rand(1L)",
    Randn(seed, hideSeed = false).sql -> "randn(1L)")
  cases.foreach { case (actual, expected) =>
    assert(actual === expected)
  }
}
3744
}

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3443,6 +3443,29 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
34433443
}
34443444
}
34453445

3446+
test("SPARK-31594: Do not display the seed of rand/randn with no argument in output schema") {
  // Captures what `df.explain()` prints to the console and asserts that the line
  // mentioning "Project" still contains a numeric seed for rand/randn, i.e. the seed
  // is hidden only from the column name and not from the explain output.
  def checkIfSeedExistsInExplain(df: DataFrame): Unit = {
    val output = new java.io.ByteArrayOutputStream()
    Console.withOut(output) {
      df.explain()
    }
    // Fail with the captured plan instead of a bare NoSuchElementException when no
    // Project line is printed.
    val projectExplainOutput = output.toString.split("\n")
      .find(_.contains("Project"))
      .getOrElse(fail(s"No Project node found in explain output:\n$output"))
    assert(projectExplainOutput.matches(""".*randn?\(-?[0-9]+\).*"""))
  }

  // rand() without a seed: the column name must not contain the generated seed.
  val df1 = sql("SELECT rand()")
  assert(df1.schema.head.name === "rand()")
  checkIfSeedExistsInExplain(df1)

  // rand(seed): the explicitly given seed stays in the column name.
  val df2 = sql("SELECT rand(1L)")
  assert(df2.schema.head.name === "rand(1)")
  checkIfSeedExistsInExplain(df2)

  // randn() without a seed. The explain check must run on df3 (previously df1 was
  // passed here by mistake, so the randn plan was never verified).
  val df3 = sql("SELECT randn()")
  assert(df3.schema.head.name === "randn()")
  checkIfSeedExistsInExplain(df3)

  // randn(seed). The explain check must run on df4 (previously df2 was passed).
  val df4 = sql("SELECT randn(1L)")
  assert(df4.schema.head.name === "randn(1)")
  checkIfSeedExistsInExplain(df4)
}
3468+
34463469
test("SPARK-31761: test byte, short, integer overflow for (Divide) integral type") {
34473470
checkAnswer(sql("Select -2147483648 DIV -1"), Seq(Row(Integer.MIN_VALUE.toLong * -1)))
34483471
checkAnswer(sql("select CAST(-128 as Byte) DIV CAST (-1 as Byte)"),

0 commit comments

Comments
 (0)