Add randLaplace and native randGamma (#156)
* Init RandGamma

* add randGamma col function

* add tests for the gamma and laplace distribution column functions

* add shape and scale to flat argument
zeotuan authored Sep 22, 2024
1 parent c16ff4f commit 3630fed
Showing 6 changed files with 191 additions and 2 deletions.
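
Before the file-by-file diff, here is a minimal usage sketch of the two column functions this commit adds (the import path and column aliases are taken from src/main/scala/org/apache/spark/sql/daria/functions.scala below; the SparkSession setup is illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.daria.functions._

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// 1,000 draws from Gamma(shape = 2.0, scale = 2.0) and Laplace(mu = 0.0, beta = 1.0);
// the functions alias their columns "gamma_random" and "laplace_random".
val df = spark.range(1000).select(randGamma(2.0, 2.0), randLaplace())
df.show(5)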
2 changes: 1 addition & 1 deletion build.sbt
@@ -12,7 +12,7 @@ val sparkVersion = "3.2.1"
libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion % "provided"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % sparkVersion % "provided"
libraryDependencies += "com.github.mrpowers" %% "spark-fast-tests" % "1.1.0" % "test"
libraryDependencies += "com.lihaoyi" %% "utest" % "0.7.11" % "test"
libraryDependencies += "com.lihaoyi" %% "utest" % "0.7.11" % "test"
libraryDependencies += "com.lihaoyi" %% "os-lib" % "0.8.0" % "test"
testFrameworks += new TestFramework("com.github.mrpowers.spark.daria.CustomFramework")

@@ -750,5 +750,4 @@ object functions {
def excelEpochToDate(colName: String): Column = {
excelEpochToDate(col(colName))
}

}
@@ -0,0 +1,84 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.math3.distribution.GammaDistribution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.util.random.XORShiftRandomAdapted

case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false) extends TernaryExpression
with ExpectsInputTypes
with Stateful
with ExpressionWithRandomSeed {

override def seedExpression: Expression = child

@transient protected lazy val seed: Long = seedExpression match {
case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
}

@transient protected lazy val shapeVal: Double = shape.dataType match {
case IntegerType => shape.eval().asInstanceOf[Int]
case LongType => shape.eval().asInstanceOf[Long]
case FloatType | DoubleType => shape.eval().asInstanceOf[Double]
}

@transient protected lazy val scaleVal: Double = scale.dataType match {
case IntegerType => scale.eval().asInstanceOf[Int]
case LongType => scale.eval().asInstanceOf[Long]
case FloatType | DoubleType => scale.eval().asInstanceOf[Double]
}

override protected def initializeInternal(partitionIndex: Int): Unit = {
distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
}
@transient private var distribution: GammaDistribution = _

def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)

def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)

override def withNewSeed(seed: Long): RandGamma = RandGamma(Literal(seed, LongType), shape, scale, hideSeed)

override protected def evalInternal(input: InternalRow): Double = distribution.sample()

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val distributionClassName = classOf[GammaDistribution].getName
val rngClassName = classOf[XORShiftRandomAdapted].getName
val disTerm = ctx.addMutableState(distributionClassName, "distribution")
ctx.addPartitionInitializationStatement(
s"$disTerm = new $distributionClassName(new $rngClassName(${seed}L + partitionIndex), $shapeVal, $scaleVal);")
ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $disTerm.sample();""",
isNull = FalseLiteral)
}

override def freshCopy(): RandGamma = RandGamma(child, shape, scale, hideSeed)

override def flatArguments: Iterator[Any] = Iterator(child, shape, scale)

override def prettyName: String = "rand_gamma"

override def sql: String = s"rand_gamma(${if (hideSeed) "" else s"${child.sql}, ${shape.sql}, ${scale.sql}"})"

override def inputTypes: Seq[AbstractDataType] = Seq(LongType, DoubleType, DoubleType)

override def dataType: DataType = DoubleType

override def first: Expression = child

override def second: Expression = shape

override def third: Expression = scale

override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
copy(child = newFirst, shape = newSecond, scale = newThird)
}

object RandGamma {
def apply(seed: Long, shape: Double, scale: Double): RandGamma = RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))
}
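
Outside of codegen, the sampling step above reduces to seeding a commons-math3 GammaDistribution with the adapted XORShift generator, offset by the partition index (see initializeInternal). A standalone sketch of that step, with illustrative values:

import org.apache.commons.math3.distribution.GammaDistribution
import org.apache.spark.util.random.XORShiftRandomAdapted

// One distribution per partition, seeded with seed + partitionIndex, as in initializeInternal.
val seed = 42L
val partitionIndex = 0
val dist = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), 2.0, 2.0)
val sample: Double = dist.sample() // a single draw from Gamma(shape = 2.0, scale = 2.0)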
26 changes: 26 additions & 0 deletions src/main/scala/org/apache/spark/sql/daria/functions.scala
@@ -0,0 +1,26 @@
package org.apache.spark.sql.daria

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Expression, RandGamma}
import org.apache.spark.sql.functions.{lit, log, rand, when}
import org.apache.spark.util.Utils

object functions {
private def withExpr(expr: Expression): Column = Column(expr)

def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random")
def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale)
def randGamma(): Column = randGamma(1.0, 1.0)

def randLaplace(seed: Long, mu: Double, beta: Double): Column = {
val mu_ = lit(mu)
val beta_ = lit(beta)
val u = rand(seed)
when(u < 0.5, mu_ + beta_ * log(lit(2) * u))
.otherwise(mu_ - beta_ * log(lit(2) * (lit(1) - u)))
.alias("laplace_random")
}

def randLaplace(mu: Double, beta: Double): Column = randLaplace(Utils.random.nextLong, mu, beta)
def randLaplace(): Column = randLaplace(0.0, 1.0)
}
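
randLaplace needs no native expression: it applies the Laplace inverse CDF to a uniform draw from rand(seed), i.e. inverse transform sampling. The same transform in plain Scala, for reference (the function name and uniform source here are illustrative):

// Inverse-CDF sampling for Laplace(mu, beta):
//   x = mu + beta * ln(2u)         if u < 0.5
//   x = mu - beta * ln(2(1 - u))   otherwise
def laplaceInverseCdf(u: Double, mu: Double, beta: Double): Double =
  if (u < 0.5) mu + beta * math.log(2.0 * u)
  else mu - beta * math.log(2.0 * (1.0 - u))

val x = laplaceInverseCdf(scala.util.Random.nextDouble(), 0.0, 1.0)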
@@ -0,0 +1,32 @@
package org.apache.spark.util.random

import org.apache.commons.math3.random.{RandomGenerator, RandomGeneratorFactory}

// copied from org.apache.spark.sql.catalyst.expressions.Rand
// adapted to apache commons math3 RandomGenerator
class XORShiftRandomAdapted(init: Long) extends java.util.Random(init: Long) with RandomGenerator {
def this() = this(System.nanoTime)

private var seed = XORShiftRandom.hashSeed(init)

override protected def next(bits: Int): Int = {
var nextSeed = seed ^ (seed << 21)
nextSeed ^= (nextSeed >>> 35)
nextSeed ^= (nextSeed << 4)
seed = nextSeed
(nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
}

override def setSeed(s: Long): Unit = {
seed = XORShiftRandom.hashSeed(s)
}

override def setSeed(s: Int): Unit = {
seed = XORShiftRandom.hashSeed(s.toLong)
}

override def setSeed(seed: Array[Int]): Unit = {
this.seed = XORShiftRandom.hashSeed(RandomGeneratorFactory.convertToLong(seed))
}
}

48 changes: 48 additions & 0 deletions src/test/scala/org/apache/spark/sql/daria/functionsTests.scala
@@ -0,0 +1,48 @@
package org.apache.spark.sql.daria

import com.github.mrpowers.spark.daria.sql.SparkSessionTestWrapper
import com.github.mrpowers.spark.fast.tests.{ColumnComparer, DataFrameComparer}
import org.apache.spark.sql.daria.functions._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.stddev
import utest._

object functionsTests extends TestSuite with DataFrameComparer with ColumnComparer with SparkSessionTestWrapper {

val tests = Tests {
'rand_gamma - {
"has correct mean and standard deviation" - {
val sourceDF = spark.range(100000).select(randGamma(2.0, 2.0))
val stats = sourceDF.agg(
mean("gamma_random").as("mean"),
stddev("gamma_random").as("stddev")
).collect()(0)

val gammaMean = stats.getAs[Double]("mean")
val gammaStddev = stats.getAs[Double]("stddev")

// Gamma distribution with shape=2.0 and scale=2.0 has mean=4.0 and stddev=sqrt(8.0)
assert(gammaMean > 0)
assert(math.abs(gammaMean - 4.0) < 0.5)
assert(math.abs(gammaStddev - math.sqrt(8.0)) < 0.5)
}
}

'rand_laplace - {
"has correct mean and standard deviation" - {
val sourceDF = spark.range(100000).select(randLaplace())
val stats = sourceDF.agg(
mean("laplace_random").as("mean"),
stddev("laplace_random").as("std_dev")
).collect()(0)

val laplaceMean = stats.getAs[Double]("mean")
val laplaceStdDev = stats.getAs[Double]("std_dev")

// Laplace distribution with mean=0.0 and scale=1.0 has mean=0.0 and stddev=sqrt(2.0)
assert(math.abs(laplaceMean) <= 0.1)
assert(math.abs(laplaceStdDev - math.sqrt(2.0)) < 0.5)
}
}
}
}
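
For reference, the tolerances asserted above follow from the standard moments of the two distributions (a sanity check, not part of the commit): Gamma(shape = 2, scale = 2) has mean = shape * scale = 4 and stddev = sqrt(shape) * scale = sqrt(8) ≈ 2.83, while Laplace(mu = 0, beta = 1) has mean = mu = 0 and stddev = sqrt(2) * beta = sqrt(2) ≈ 1.41.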
