Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mutable MomentsState, more efficient + for double added to Moments #1050

Merged
merged 11 commits into from
Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 145 additions & 58 deletions algebird-core/src/main/scala/com/twitter/algebird/MomentsGroup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,55 @@ sealed class Moments(val m0D: Double, val m1: Double, val m2: Double, val m3: Do
math.sqrt(m0D) * m3 / math.pow(m2, 1.5)
sritchie marked this conversation as resolved.
Show resolved Hide resolved

def kurtosis: Double =
m0D * m4 / math.pow(m2, 2) - 3
m0D * m4 / (m2 * m2) - 3

/**
* Combines this instance with another [[Moments]] instance.
* @param b
* the other instance
* @return
* a [[Moments]] instances representing the combined moments of this
* instance and and `b`
*/
def +(b: Moments): Moments = Moments.momentsMonoid.plus(this, b)

/**
* Returns a new [[Moments]] instance generated by merging in the new
* observation `b`.
* @param b
* a new observation
* @return
* a [[Moments]] instance representing the combined moments of this
* instance and and `b`.
*/
def +(b: Double): Moments = {
val n = m0D + 1
val delta = b - mean
val delta_n = delta / n
val delta_n2 = delta_n * delta_n
val term1 = delta * delta_n * m0D

val meanCombined = Moments.getCombinedMeanDouble(m0D, mean, 1.0, b)
val m2combined = m2 + term1
val m3combined = m3 + term1 * delta_n * (n - 2) - 3 * delta_n * m2
val m4combined = m4 + term1 * delta_n2 * (n * n - 3 * n + 3) +
6 * delta_n2 * m2 - 4 * delta_n * m3

new Moments(n, meanCombined, m2combined, m3combined, m4combined)
}

/**
* Returns a [[Fold]] instance that uses `+` to accumulate deltas into this
* [[Moments]] instance.
*/
def fold: Fold[Double, Moments] =
Copy link
Collaborator Author

@sritchie sritchie Jan 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the style we used in ExpHist, where the fold is an instance method that lets you fold into THIS Moments, implemented by

_ => Moments.MomentsState.fromMoments(this)

Let me know if you want to just have an empty fold on Moments.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh... looking at it now, it seems a bit strange to me, but I can see the benefit of being able to start with a non empty value to resume some updates.

Fold.foldMutable[Moments.MomentsState, Double, Moments](
{ case (state, x) =>
state += x
},
_ => Moments.MomentsState.newEmpty,
(state: Moments.MomentsState) => this + state.toMoments
)

override def productArity: Int = 5
override def productElement(idx: Int): Any =
Expand Down Expand Up @@ -102,6 +150,92 @@ sealed class Moments(val m0D: Double, val m1: Double, val m2: Double, val m3: Do
}

object Moments {
final class MomentsState(
var count: Double,
var mean: Double,
var m2: Double,
var m3: Double,
var m4: Double) {

def +=(b: Moments): this.type = {
/*
* Unfortunately we copy the code from the monoid's plus implementation,
* but we do it to avoid allocating a new Moments on every item in the
* loop. the Monoid laws test that sum matches looping on plus
*/
val countCombined = count + b.m0D
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we could use a variant of Kahan summation on countCombined, right

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, and if we had the fold version, or KahanState, it would fit nicely here.


if (countCombined == 0.0) {
mean = 0.0
m2 = 0.0
m3 = 0.0
m4 = 0.0
} else {
val delta = b.mean - mean
val delta_n = delta / countCombined
val delta_n2 = delta_n * delta_n
val delta_n3 = delta_n2 * delta_n
val count_sq = count * count
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also killed the math.pow calls everywhere.

val rn_sq = b.m0D * b.m0D

val meanCombined = Moments.getCombinedMeanDouble(count, mean, b.m0D, b.mean)

val m2Combined = m2 + b.m2 + delta * delta_n * count * b.m0D

val m3Combined = m3 + b.m3 +
delta * delta_n2 * count * b.m0D * (count - b.m0D) +
3 * delta_n * (count * b.m2 - b.m0D * m2)

val m4Combined = m4 + b.m4 +
delta * delta_n3 * count * b.m0D *
(count_sq - count * b.m0D + rn_sq) +
6 * delta_n2 * (count_sq * b.m2 + rn_sq * m2) +
4 * delta_n * (count * b.m3 - b.m0D * m3)

mean = meanCombined
m2 = m2Combined
m3 = m3Combined
m4 = m4Combined
}

count = countCombined
this
}

def +=(b: Double): this.type = {
val prevCount = count
count += 1

val delta = b - mean
val delta_n = delta / count
val delta_n2 = delta_n * delta_n
val term1 = delta * delta_n * prevCount

mean = Moments.getCombinedMeanDouble(prevCount, mean, 1.0, b)
m4 += term1 * delta_n2 * (count * count - 3 * count + 3) +
6 * delta_n2 * m2 - 4 * delta_n * m3
m3 += term1 * delta_n * (count - 2) - 3 * delta_n * m2
m2 += term1
this
}

def toMoments: Moments = new Moments(count, mean, m2, m3, m4)

def resetFromMoments(m: Moments): this.type = {
count = m.m0D
mean = m.m1
m2 = m.m2
m3 = m.m3
m4 = m.m4
this
}
}

object MomentsState {
def newEmpty: MomentsState =
sritchie marked this conversation as resolved.
Show resolved Hide resolved
new MomentsState(0.0, 0.0, 0.0, 0.0, 0.0)
}

@deprecated("use monoid[Moments], this isn't lawful for negate", "0.13.8")
def group: Group[Moments] with CommutativeGroup[Moments] =
MomentsGroup
Expand All @@ -111,6 +245,8 @@ object Moments {

val aggregator: MomentsAggregator.type = MomentsAggregator

val fold: Fold[Double, Moments] = momentsMonoid.zero.fold
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here we go, the zero fold.


def numericAggregator[N](implicit num: Numeric[N]): MonoidAggregator[N, Moments, Moments] =
Aggregator.prepareMonoid { n: N => Moments(num.toDouble(n)) }

Expand Down Expand Up @@ -211,6 +347,8 @@ class MomentsMonoid extends Monoid[Moments] with CommutativeMonoid[Moments] {
val delta_n = delta / countCombined
val delta_n2 = delta_n * delta_n
val delta_n3 = delta_n2 * delta_n
val ln_sq = a.m0D * a.m0D
val rn_sq = b.m0D * b.m0D

val meanCombined = Moments.getCombinedMeanDouble(a.m0D, a.mean, b.m0D, b.mean)

Expand All @@ -221,9 +359,9 @@ class MomentsMonoid extends Monoid[Moments] with CommutativeMonoid[Moments] {
3 * delta_n * (a.m0D * b.m2 - b.m0D * a.m2)

val m4 = a.m4 + b.m4 +
delta * delta_n3 * a.m0D * b.m0D * (math.pow(a.m0D, 2) -
a.m0D * b.m0D + math.pow(b.m0D, 2)) +
6 * delta_n2 * (math.pow(a.m0D, 2) * b.m2 + math.pow(b.m0D, 2) * a.m2) +
delta * delta_n3 * a.m0D * b.m0D * (ln_sq -
a.m0D * b.m0D + rn_sq) +
6 * delta_n2 * (ln_sq * b.m2 + rn_sq * a.m2) +
4 * delta_n * (a.m0D * b.m3 - b.m0D * a.m3)

new Moments(countCombined, meanCombined, m2, m3, m4)
Expand All @@ -233,63 +371,12 @@ class MomentsMonoid extends Monoid[Moments] with CommutativeMonoid[Moments] {
override def sumOption(items: TraversableOnce[Moments]): Option[Moments] =
if (items.isEmpty) None
else {
val state = Moments.MomentsState.newEmpty
val iter = items.toIterator

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this implementation lives in MomentsState now.

val init = iter.next()

var count: Double = init.m0D
var mean: Double = init.mean
var m2: Double = init.m2
var m3: Double = init.m3
var m4: Double = init.m4

while (iter.hasNext) {

/*
* Unfortunately we copy the code in plus, but we do
* it to avoid allocating a new Moments on every item
* in the loop. the Monoid laws test that sum
* matches looping on plus
*/
val b = iter.next()

val countCombined = count + b.m0D

if (countCombined == 0.0) {
mean = 0.0
m2 = 0.0
m3 = 0.0
m4 = 0.0
} else {
val delta = b.mean - mean
val delta_n = delta / countCombined
val delta_n2 = delta_n * delta_n
val delta_n3 = delta_n2 * delta_n

val meanCombined = Moments.getCombinedMeanDouble(count, mean, b.m0D, b.mean)

val m2Combined = m2 + b.m2 + delta * delta_n * count * b.m0D

val m3Combined = m3 + b.m3 +
delta * delta_n2 * count * b.m0D * (count - b.m0D) +
3 * delta_n * (count * b.m2 - b.m0D * m2)

val m4Combined = m4 + b.m4 +
delta * delta_n3 * count * b.m0D * (math.pow(count, 2) -
count * b.m0D + math.pow(b.m0D, 2)) +
6 * delta_n2 * (math.pow(count, 2) * b.m2 + math.pow(b.m0D, 2) * m2) +
4 * delta_n * (count * b.m3 - b.m0D * m3)

mean = meanCombined
m2 = m2Combined
m3 = m3Combined
m4 = m4Combined
}

count = countCombined
state += iter.next()
}

Some(new Moments(count, mean, m2, m3, m4))
Some(state.toMoments)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ class MomentsLaws extends CheckProperties {
val recur = Gen.lzy(opBasedGen[A](genA))
val pair = Gen.zip(recur, recur)

import Operators.Ops

Gen.frequency(
(10, init),
(1, pair.map { case (a, b) => a + b })
Expand Down Expand Up @@ -56,8 +54,9 @@ class MomentsLaws extends CheckProperties {
}

property("scaling by a and b is the same as scaling by a*b; similarly for addition") {
// use Int here instead of doubles so that we don't have to worry about overlfowing to Infinity and having to
// fine-tune numerical precision thresholds.
// use Int here instead of doubles so that we don't have to worry about
// overflowing to Infinity and having to fine-tune numerical precision
// thresholds.
forAll(opGen, Gen.choose(0, Int.MaxValue), Gen.choose(0, Int.MaxValue)) { (mom, a0, b0) =>
val a = a0 & Int.MaxValue
val b = b0 & Int.MaxValue
Expand All @@ -75,6 +74,39 @@ class MomentsLaws extends CheckProperties {
}
}

property("adding double matches adding singleton Moments instance") {
forAll(opGen, Gen.choose(0, Int.MaxValue)) { (mom, x) =>
val plusMoments = mom + Moments(x)
val plusDouble = mom + x
equiv.equiv(plusMoments, plusDouble)
}
}

property("adding doubles via +, fold, aggregator should match") {
forAll(opGen, Gen.containerOf[Seq, Double](Gen.choose(0, 1000))) {
(mom, xs) =>
val fullViaAdd = xs.foldLeft(mom)(_ + _)
val fullViaFold = mom.fold.overTraversable(xs)
val fullViaAgg = mom + MomentsAggregator(xs)

equiv.equiv(fullViaAdd, fullViaFold)
equiv.equiv(fullViaAdd, fullViaAgg)
}
}

property("adding Moment instances via +, sumOption should match") {
forAll(opGen, Gen.containerOf[Seq, Double](Gen.choose(0, 1000))) {
(mom, ints) =>
val xs = ints.map(Moments(_)).toTraversable
val monoid = Moments.momentsMonoid

val fullViaAdd = xs.foldLeft(mom)(_ + _)
val fullViaMonoid = mom + monoid.sumOption(xs).getOrElse(monoid.zero)

equiv.equiv(fullViaAdd, fullViaMonoid)
}
}

property("scaling does affect total weight, doesn't affect mean, variance, or moments") {
// def sign(x: Int): Int = if (x < 0) -1 else 1
forAll(opGen, Gen.choose(0, Int.MaxValue)) { (mom, a0) =>
Expand All @@ -100,7 +132,7 @@ class MomentsTest extends AnyWordSpec with Matchers {
* Given a list of doubles, create a Moments object to hold the list's central moments.
*/
def getMoments(xs: List[Double]): Moments =
MomentsAggregator(xs)
Moments.aggregator(xs)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I switched this so we get coverage on the aggregator line :)


"Moments should count" in {
val m1 = getMoments(List(1, 2, 3, 4, 5))
Expand Down
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ val sharedSettings = Seq(
Nil
}
},
javacOptions ++= Seq("-target", "1.6", "-source", "1.6"),
javacOptions ++= Seq("-target", "1.8", "-source", "1.8"),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here is the build change.

libraryDependencies ++= Seq(
"junit" % "junit" % "4.13.2" % Test,
"com.github.sbt" % "junit-interface" % "0.13.3" % Test
Expand Down