-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add randLaplace and native randGamma (#156)
* Init RandGamma * add randGamma col function * add test for gamma and laplace distribution column function * add shape and scale to flat argument
- Loading branch information
Showing
6 changed files
with
191 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -750,5 +750,4 @@ object functions { | |
def excelEpochToDate(colName: String): Column = { | ||
excelEpochToDate(col(colName)) | ||
} | ||
|
||
} |
84 changes: 84 additions & 0 deletions
84
src/main/scala/org/apache/spark/sql/catalyst/expressions/RandGamma.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
package org.apache.spark.sql.catalyst.expressions | ||
|
||
import org.apache.commons.math3.distribution.GammaDistribution | ||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed | ||
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral | ||
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper | ||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.util.random.XORShiftRandomAdapted | ||
|
||
case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false) extends TernaryExpression | ||
with ExpectsInputTypes | ||
with Stateful | ||
with ExpressionWithRandomSeed { | ||
|
||
override def seedExpression: Expression = child | ||
|
||
@transient protected lazy val seed: Long = seedExpression match { | ||
case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int] | ||
case e if e.dataType == LongType => e.eval().asInstanceOf[Long] | ||
} | ||
|
||
@transient protected lazy val shapeVal: Double = shape.dataType match { | ||
case IntegerType => shape.eval().asInstanceOf[Int] | ||
case LongType => shape.eval().asInstanceOf[Long] | ||
case FloatType | DoubleType => shape.eval().asInstanceOf[Double] | ||
} | ||
|
||
@transient protected lazy val scaleVal: Double = scale.dataType match { | ||
case IntegerType => scale.eval().asInstanceOf[Int] | ||
case LongType => scale.eval().asInstanceOf[Long] | ||
case FloatType | DoubleType => scale.eval().asInstanceOf[Double] | ||
} | ||
|
||
override protected def initializeInternal(partitionIndex: Int): Unit = { | ||
distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal) | ||
} | ||
@transient private var distribution: GammaDistribution = _ | ||
|
||
def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true) | ||
|
||
def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false) | ||
|
||
override def withNewSeed(seed: Long): RandGamma = RandGamma(Literal(seed, LongType), shape, scale, hideSeed) | ||
|
||
override protected def evalInternal(input: InternalRow): Double = distribution.sample() | ||
|
||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
val distributionClassName = classOf[GammaDistribution].getName | ||
val rngClassName = classOf[XORShiftRandomAdapted].getName | ||
val disTerm = ctx.addMutableState(distributionClassName, "distribution") | ||
ctx.addPartitionInitializationStatement( | ||
s"$disTerm = new $distributionClassName(new $rngClassName(${seed}L + partitionIndex), $shapeVal, $scaleVal);") | ||
ev.copy(code = code""" | ||
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $disTerm.sample();""", | ||
isNull = FalseLiteral) | ||
} | ||
|
||
override def freshCopy(): RandGamma = RandGamma(child, shape, scale, hideSeed) | ||
|
||
override def flatArguments: Iterator[Any] = Iterator(child, shape, scale) | ||
|
||
override def prettyName: String = "rand_gamma" | ||
|
||
override def sql: String = s"rand_gamma(${if (hideSeed) "" else s"${child.sql}, ${shape.sql}, ${scale.sql}"})" | ||
|
||
override def inputTypes: Seq[AbstractDataType] = Seq(LongType, DoubleType, DoubleType) | ||
|
||
override def dataType: DataType = DoubleType | ||
|
||
override def first: Expression = child | ||
|
||
override def second: Expression = shape | ||
|
||
override def third: Expression = scale | ||
|
||
override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = | ||
copy(child = newFirst, shape = newSecond, scale = newThird) | ||
} | ||
|
||
object RandGamma { | ||
def apply(seed: Long, shape: Double, scale: Double): RandGamma = RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
package org.apache.spark.sql.daria | ||
|
||
import org.apache.spark.sql.Column | ||
import org.apache.spark.sql.catalyst.expressions.{Expression, RandGamma} | ||
import org.apache.spark.sql.functions.{lit, log, rand, when} | ||
import org.apache.spark.util.Utils | ||
|
||
object functions { | ||
private def withExpr(expr: Expression): Column = Column(expr) | ||
|
||
def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random") | ||
def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale) | ||
def randGamma(): Column = randGamma(1.0, 1.0) | ||
|
||
def randLaplace(seed: Long, mu: Double, beta: Double): Column = { | ||
val mu_ = lit(mu) | ||
val beta_ = lit(beta) | ||
val u = rand(seed) | ||
when(u < 0.5, mu_ + beta_ * log(lit(2) * u)) | ||
.otherwise(mu_ - beta_ * log(lit(2) * (lit(1) - u))) | ||
.alias("laplace_random") | ||
} | ||
|
||
def randLaplace(mu: Double, beta: Double): Column = randLaplace(Utils.random.nextLong, mu, beta) | ||
def randLaplace(): Column = randLaplace(0.0, 1.0) | ||
} |
32 changes: 32 additions & 0 deletions
32
src/main/scala/org/apache/spark/util/random/XORShiftRandomAdapted.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package org.apache.spark.util.random | ||
|
||
import org.apache.commons.math3.random.{RandomGenerator, RandomGeneratorFactory} | ||
|
||
// copied from org.apache.spark.sql.catalyst.expressions.Rand | ||
// adapted to apache commons math3 RandomGenerator | ||
class XORShiftRandomAdapted(init: Long) extends java.util.Random(init: Long) with RandomGenerator { | ||
def this() = this(System.nanoTime) | ||
|
||
private var seed = XORShiftRandom.hashSeed(init) | ||
|
||
override protected def next(bits: Int): Int = { | ||
var nextSeed = seed ^ (seed << 21) | ||
nextSeed ^= (nextSeed >>> 35) | ||
nextSeed ^= (nextSeed << 4) | ||
seed = nextSeed | ||
(nextSeed & ((1L << bits) -1)).asInstanceOf[Int] | ||
} | ||
|
||
override def setSeed(s: Long): Unit = { | ||
seed = XORShiftRandom.hashSeed(s) | ||
} | ||
|
||
override def setSeed(s: Int): Unit = { | ||
seed = XORShiftRandom.hashSeed(s.toLong) | ||
} | ||
|
||
override def setSeed(seed: Array[Int]): Unit = { | ||
this.seed = XORShiftRandom.hashSeed(RandomGeneratorFactory.convertToLong(seed)) | ||
} | ||
} | ||
|
48 changes: 48 additions & 0 deletions
48
src/test/scala/org/apache/spark/sql/daria/functionsTests.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
package org.apache.spark.sql.daria | ||
|
||
import com.github.mrpowers.spark.daria.sql.SparkSessionTestWrapper | ||
import com.github.mrpowers.spark.fast.tests.{ColumnComparer, DataFrameComparer} | ||
import org.apache.spark.sql.daria.functions._ | ||
import org.apache.spark.sql.functions._ | ||
import org.apache.spark.sql.functions.stddev | ||
import utest._ | ||
|
||
object functionsTests extends TestSuite with DataFrameComparer with ColumnComparer with SparkSessionTestWrapper { | ||
|
||
val tests = Tests { | ||
'rand_gamma - { | ||
"has correct mean and standard deviation" - { | ||
val sourceDF = spark.range(100000).select(randGamma(2.0, 2.0)) | ||
val stats = sourceDF.agg( | ||
mean("gamma_random").as("mean"), | ||
stddev("gamma_random").as("stddev") | ||
).collect()(0) | ||
|
||
val gammaMean = stats.getAs[Double]("mean") | ||
val gammaStddev = stats.getAs[Double]("stddev") | ||
|
||
// Gamma distribution with shape=2.0 and scale=2.0 has mean=4.0 and stddev=sqrt(8.0) | ||
assert(gammaMean > 0) | ||
assert(math.abs(gammaMean - 4.0) < 0.5) | ||
assert(math.abs(gammaStddev - math.sqrt(8.0)) < 0.5) | ||
} | ||
} | ||
|
||
'rand_laplace - { | ||
"has correct mean and standard deviation" - { | ||
val sourceDF = spark.range(100000).select(randLaplace()) | ||
val stats = sourceDF.agg( | ||
mean("laplace_random").as("mean"), | ||
stddev("laplace_random").as("std_dev") | ||
).collect()(0) | ||
|
||
val laplaceMean = stats.getAs[Double]("mean") | ||
val laplaceStdDev = stats.getAs[Double]("std_dev") | ||
|
||
// Laplace distribution with mean=0.0 and scale=1.0 has mean=0.0 and stddev=sqrt(2.0) | ||
assert(math.abs(laplaceMean) <= 0.1) | ||
assert(math.abs(laplaceStdDev - math.sqrt(2.0)) < 0.5) | ||
} | ||
} | ||
} | ||
} |