Skip to content

Commit

Permalink
more efficient moments (#1049)
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie authored Jan 12, 2022
1 parent 0861d71 commit d96093b
Showing 1 changed file with 22 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,21 +208,23 @@ class MomentsMonoid extends Monoid[Moments] with CommutativeMonoid[Moments] {
if (countCombined == 0.0) zero
else {
val delta = b.mean - a.mean
val delta_n = delta / countCombined
val delta_n2 = delta_n * delta_n
val delta_n3 = delta_n2 * delta_n

val meanCombined = Moments.getCombinedMeanDouble(a.m0D, a.mean, b.m0D, b.mean)

val m2 = a.m2 + b.m2 +
math.pow(delta, 2) * a.m0D * b.m0D / countCombined
val m2 = a.m2 + b.m2 + delta * delta_n * a.m0D * b.m0D

val m3 = a.m3 + b.m3 +
math.pow(delta, 3) * a.m0D * b.m0D * (a.m0D - b.m0D) / math.pow(countCombined, 2) +
3 * delta * (a.m0D * b.m2 - b.m0D * a.m2) / countCombined
delta * delta_n2 * a.m0D * b.m0D * (a.m0D - b.m0D) +
3 * delta_n * (a.m0D * b.m2 - b.m0D * a.m2)

val m4 = a.m4 + b.m4 +
math.pow(delta, 4) * a.m0D * b.m0D * (math.pow(a.m0D, 2) -
a.m0D * b.m0D + math.pow(b.m0D, 2)) / math.pow(countCombined, 3) +
6 * math.pow(delta, 2) * (math.pow(a.m0D, 2) * b.m2 +
math.pow(b.m0D, 2) * a.m2) / math.pow(countCombined, 2) +
4 * delta * (a.m0D * b.m3 - b.m0D * a.m3) / countCombined
delta * delta_n3 * a.m0D * b.m0D * (math.pow(a.m0D, 2) -
a.m0D * b.m0D + math.pow(b.m0D, 2)) +
6 * delta_n2 * (math.pow(a.m0D, 2) * b.m2 + math.pow(b.m0D, 2) * a.m2) +
4 * delta_n * (a.m0D * b.m3 - b.m0D * a.m3)

new Moments(countCombined, meanCombined, m2, m3, m4)
}
Expand Down Expand Up @@ -260,21 +262,23 @@ class MomentsMonoid extends Monoid[Moments] with CommutativeMonoid[Moments] {
m4 = 0.0
} else {
val delta = b.mean - mean
val delta_n = delta / countCombined
val delta_n2 = delta_n * delta_n
val delta_n3 = delta_n2 * delta_n

val meanCombined = Moments.getCombinedMeanDouble(count, mean, b.m0D, b.mean)

val m2Combined = m2 + b.m2 +
math.pow(delta, 2) * count * b.m0D / countCombined
val m2Combined = m2 + b.m2 + delta * delta_n * count * b.m0D

val m3Combined = m3 + b.m3 +
math.pow(delta, 3) * count * b.m0D * (count - b.m0D) / math.pow(countCombined, 2) +
3 * delta * (count * b.m2 - b.m0D * m2) / countCombined
delta * delta_n2 * count * b.m0D * (count - b.m0D) +
3 * delta_n * (count * b.m2 - b.m0D * m2)

val m4Combined = m4 + b.m4 +
math.pow(delta, 4) * count * b.m0D * (math.pow(count, 2) -
count * b.m0D + math.pow(b.m0D, 2)) / math.pow(countCombined, 3) +
6 * math.pow(delta, 2) * (math.pow(count, 2) * b.m2 +
math.pow(b.m0D, 2) * m2) / math.pow(countCombined, 2) +
4 * delta * (count * b.m3 - b.m0D * m3) / countCombined
delta * delta_n3 * count * b.m0D * (math.pow(count, 2) -
count * b.m0D + math.pow(b.m0D, 2)) +
6 * delta_n2 * (math.pow(count, 2) * b.m2 + math.pow(b.m0D, 2) * m2) +
4 * delta_n * (count * b.m3 - b.m0D * m3)

mean = meanCombined
m2 = m2Combined
Expand Down

0 comments on commit d96093b

Please sign in to comment.