
Commit 63eee86

viirya authored and davies committed
[SPARK-9297] [SQL] Add covar_pop and covar_samp
JIRA: https://issues.apache.org/jira/browse/SPARK-9297

Add two aggregation functions: covar_pop and covar_samp.

Author: Liang-Chi Hsieh <viirya@gmail.com>
Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #10029 from viirya/covar-funcs.
1 parent d6fd9b3 commit 63eee86
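
For reference (not part of the commit message): the two aggregates compute the standard population and sample covariance over the non-null pairs in a group, with $\bar x$, $\bar y$ the column means and $n$ the pair count:

$$\mathrm{covar\_pop}(x, y) = \frac{1}{n}\sum_{i=1}^{n}(x_i - \bar x)(y_i - \bar y), \qquad \mathrm{covar\_samp}(x, y) = \frac{1}{n-1}\sum_{i=1}^{n}(x_i - \bar x)(y_i - \bar y)$$

covar_samp returns NULL when $n \le 1$, and both return NULL on empty input, matching the eval implementations below.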

File tree

4 files changed: +272 -0 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 0 deletions

@@ -182,6 +182,8 @@ object FunctionRegistry {
     expression[Average]("avg"),
     expression[Corr]("corr"),
     expression[Count]("count"),
+    expression[CovPopulation]("covar_pop"),
+    expression[CovSample]("covar_samp"),
     expression[First]("first"),
     expression[First]("first_value"),
     expression[Last]("last"),
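
Because the expressions are registered by name here, the new aggregates become callable from SQL as well as the DataFrame API. A minimal usage sketch, assuming a sqlContext in scope and a hypothetical temp table covariance_test with double columns a and b (neither is part of this commit):

// Hypothetical table and context; only the function names covar_pop/covar_samp come from this change.
val covDF = sqlContext.sql("SELECT covar_pop(a, b), covar_samp(a, b) FROM covariance_test")
covDF.show()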
Lines changed: 198 additions & 0 deletions

@@ -0,0 +1,198 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

/**
 * Compute the covariance between two expressions.
 * When applied on empty data (i.e., count is zero), it returns NULL.
 */
abstract class Covariance(left: Expression, right: Expression) extends ImperativeAggregate
  with Serializable {
  override def children: Seq[Expression] = Seq(left, right)

  override def nullable: Boolean = true

  override def dataType: DataType = DoubleType

  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)

  override def checkInputDataTypes(): TypeCheckResult = {
    if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(
        s"covariance requires that both arguments are double type, " +
          s"not (${left.dataType}, ${right.dataType}).")
    }
  }

  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

  override def inputAggBufferAttributes: Seq[AttributeReference] = {
    aggBufferAttributes.map(_.newInstance())
  }

  override val aggBufferAttributes: Seq[AttributeReference] = Seq(
    AttributeReference("xAvg", DoubleType)(),
    AttributeReference("yAvg", DoubleType)(),
    AttributeReference("Ck", DoubleType)(),
    AttributeReference("count", LongType)())

  // Local cache of mutableAggBufferOffset(s) that will be used in update and merge
  val xAvgOffset = mutableAggBufferOffset
  val yAvgOffset = mutableAggBufferOffset + 1
  val CkOffset = mutableAggBufferOffset + 2
  val countOffset = mutableAggBufferOffset + 3

  // Local cache of inputAggBufferOffset(s) that will be used in update and merge
  val inputXAvgOffset = inputAggBufferOffset
  val inputYAvgOffset = inputAggBufferOffset + 1
  val inputCkOffset = inputAggBufferOffset + 2
  val inputCountOffset = inputAggBufferOffset + 3

  override def initialize(buffer: MutableRow): Unit = {
    buffer.setDouble(xAvgOffset, 0.0)
    buffer.setDouble(yAvgOffset, 0.0)
    buffer.setDouble(CkOffset, 0.0)
    buffer.setLong(countOffset, 0L)
  }

  override def update(buffer: MutableRow, input: InternalRow): Unit = {
    val leftEval = left.eval(input)
    val rightEval = right.eval(input)

    if (leftEval != null && rightEval != null) {
      val x = leftEval.asInstanceOf[Double]
      val y = rightEval.asInstanceOf[Double]

      var xAvg = buffer.getDouble(xAvgOffset)
      var yAvg = buffer.getDouble(yAvgOffset)
      var Ck = buffer.getDouble(CkOffset)
      var count = buffer.getLong(countOffset)

      val deltaX = x - xAvg
      val deltaY = y - yAvg
      count += 1
      xAvg += deltaX / count
      yAvg += deltaY / count
      Ck += deltaX * (y - yAvg)

      buffer.setDouble(xAvgOffset, xAvg)
      buffer.setDouble(yAvgOffset, yAvg)
      buffer.setDouble(CkOffset, Ck)
      buffer.setLong(countOffset, count)
    }
  }

  // Merge counters from other partitions. Formula can be found at:
  // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
    val count2 = buffer2.getLong(inputCountOffset)

    // We only go to merge two buffers if there is at least one record aggregated in buffer2.
    // We don't need to check count in buffer1 because if count2 is more than zero, totalCount
    // is more than zero too, then we won't get a divide by zero exception.
    if (count2 > 0) {
      var xAvg = buffer1.getDouble(xAvgOffset)
      var yAvg = buffer1.getDouble(yAvgOffset)
      var Ck = buffer1.getDouble(CkOffset)
      var count = buffer1.getLong(countOffset)

      val xAvg2 = buffer2.getDouble(inputXAvgOffset)
      val yAvg2 = buffer2.getDouble(inputYAvgOffset)
      val Ck2 = buffer2.getDouble(inputCkOffset)

      val totalCount = count + count2
      val deltaX = xAvg - xAvg2
      val deltaY = yAvg - yAvg2
      Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
      xAvg = (xAvg * count + xAvg2 * count2) / totalCount
      yAvg = (yAvg * count + yAvg2 * count2) / totalCount
      count = totalCount

      buffer1.setDouble(xAvgOffset, xAvg)
      buffer1.setDouble(yAvgOffset, yAvg)
      buffer1.setDouble(CkOffset, Ck)
      buffer1.setLong(countOffset, count)
    }
  }
}

case class CovSample(
    left: Expression,
    right: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends Covariance(left, right) {

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def eval(buffer: InternalRow): Any = {
    val count = buffer.getLong(countOffset)
    if (count > 1) {
      val Ck = buffer.getDouble(CkOffset)
      val cov = Ck / (count - 1)
      if (cov.isNaN) {
        null
      } else {
        cov
      }
    } else {
      null
    }
  }
}

case class CovPopulation(
    left: Expression,
    right: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends Covariance(left, right) {

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def eval(buffer: InternalRow): Any = {
    val count = buffer.getLong(countOffset)
    if (count > 0) {
      val Ck = buffer.getDouble(CkOffset)
      if (Ck.isNaN) {
        null
      } else {
        Ck / count
      }
    } else {
      null
    }
  }
}
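
The update/merge pair above is the standard one-pass (Welford-style) co-moment algorithm: each buffer tracks the running means xAvg/yAvg, the co-moment Ck, and a count, and partial buffers combine with the pairwise formula referenced from the Wikipedia article. A self-contained sketch of the same arithmetic on plain doubles, outside Spark (the CovState name and the demo data are ours, not part of the commit):

case class CovState(xAvg: Double = 0.0, yAvg: Double = 0.0, ck: Double = 0.0, count: Long = 0L) {

  // One record: the same Welford-style update used in Covariance.update above.
  def add(x: Double, y: Double): CovState = {
    val n = count + 1
    val deltaX = x - xAvg
    val newXAvg = xAvg + deltaX / n
    val newYAvg = yAvg + (y - yAvg) / n
    // ck ends up equal to sum((x_i - xAvg)(y_i - yAvg)) over the records seen so far.
    CovState(newXAvg, newYAvg, ck + deltaX * (y - newYAvg), n)
  }

  // Pairwise combination of two partial buffers, as in Covariance.merge above.
  def merge(that: CovState): CovState = {
    if (that.count == 0) this
    else if (count == 0) that
    else {
      val total = count + that.count
      val deltaX = xAvg - that.xAvg
      val deltaY = yAvg - that.yAvg
      CovState(
        (xAvg * count + that.xAvg * that.count) / total,
        (yAvg * count + that.yAvg * that.count) / total,
        ck + that.ck + deltaX * deltaY * count.toDouble * that.count / total,
        total)
    }
  }

  def covarPop: Option[Double] = if (count > 0) Some(ck / count) else None
  def covarSamp: Option[Double] = if (count > 1) Some(ck / (count - 1)) else None
}

object CovStateDemo extends App {
  val xs = (0 until 20).map(_.toDouble)
  val ys = xs.map(x => x * x - 2 * x + 3.5)
  // Aggregate two "partitions" independently, then merge, as Spark does per partition.
  val (p1, p2) = xs.zip(ys).splitAt(7)
  val s1 = p1.foldLeft(CovState()) { case (s, (x, y)) => s.add(x, y) }
  val s2 = p2.foldLeft(CovState()) { case (s, (x, y)) => s.add(x, y) }
  val merged = s1.merge(s2)
  println(merged.covarSamp)  // ~ Some(595.0), matching the test suite below
  println(merged.covarPop)   // ~ Some(565.25)
}

Splitting the input at any point and merging the partial states reproduces the single-pass result (up to floating-point rounding), which is what lets Spark aggregate each partition separately and then combine the buffers.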

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 40 additions & 0 deletions
@@ -308,6 +308,46 @@ object functions extends LegacyFunctions {
   def countDistinct(columnName: String, columnNames: String*): Column =
     countDistinct(Column(columnName), columnNames.map(Column.apply) : _*)
 
+  /**
+   * Aggregate function: returns the population covariance for two columns.
+   *
+   * @group agg_funcs
+   * @since 2.0.0
+   */
+  def covar_pop(column1: Column, column2: Column): Column = withAggregateFunction {
+    CovPopulation(column1.expr, column2.expr)
+  }
+
+  /**
+   * Aggregate function: returns the population covariance for two columns.
+   *
+   * @group agg_funcs
+   * @since 2.0.0
+   */
+  def covar_pop(columnName1: String, columnName2: String): Column = {
+    covar_pop(Column(columnName1), Column(columnName2))
+  }
+
+  /**
+   * Aggregate function: returns the sample covariance for two columns.
+   *
+   * @group agg_funcs
+   * @since 2.0.0
+   */
+  def covar_samp(column1: Column, column2: Column): Column = withAggregateFunction {
+    CovSample(column1.expr, column2.expr)
+  }
+
+  /**
+   * Aggregate function: returns the sample covariance for two columns.
+   *
+   * @group agg_funcs
+   * @since 2.0.0
+   */
+  def covar_samp(columnName1: String, columnName2: String): Column = {
+    covar_samp(Column(columnName1), Column(columnName2))
+  }
+
   /**
    * Aggregate function: returns the first value in a group.
    *
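
These wrappers expose the aggregates through the DataFrame API under both Column and column-name overloads. A minimal usage sketch mirroring the test case below (the input DataFrame df with double columns a and b is assumed, not part of this change):

import org.apache.spark.sql.functions._

// Assumed: `df` has double columns "a" and "b"; grouping over the whole frame,
// as in the test, yields a single row with both covariances.
val covs = df.groupBy().agg(covar_pop("a", "b"), covar_samp("a", "b"))
covs.show()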

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 32 additions & 0 deletions
@@ -807,6 +807,38 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
   }
 
+  test("covariance: covar_pop and covar_samp") {
+    // non-trivial example. To reproduce in python, use:
+    // >>> import numpy as np
+    // >>> a = np.array(range(20))
+    // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
+    // >>> np.cov(a, b, bias = 0)[0][1]
+    // 595.0
+    // >>> np.cov(a, b, bias = 1)[0][1]
+    // 565.25
+    val df = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
+    val cov_samp = df.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
+    assert(math.abs(cov_samp - 595.0) < 1e-12)
+
+    val cov_pop = df.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+    assert(math.abs(cov_pop - 565.25) < 1e-12)
+
+    val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
+    val cov_samp2 = df2.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
+    assert(math.abs(cov_samp2 - 11564.0) < 1e-12)
+
+    val cov_pop2 = df2.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+    assert(math.abs(cov_pop2 - 10985.799999999999) < 1e-12)
+
+    // one row test
+    val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
+    val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0)
+    assert(cov_samp3 == null)
+
+    val cov_pop3 = df3.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+    assert(cov_pop3 == 0.0)
+  }
+
   test("no aggregation function (SPARK-11486)") {
     val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s")
       .groupBy("s").count()
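
For reference, the expected constants in the first test case come straight from the two-pass covariance definition; a standalone sketch (not part of the suite):

object ExpectedCovariance extends App {
  val a = (0 until 20).map(_.toDouble)
  val b = a.map(x => x * x - 2 * x + 3.5)
  val aMean = a.sum / a.size
  val bMean = b.sum / b.size
  // Co-moment: sum of the products of deviations from the two means.
  val ck = a.zip(b).map { case (x, y) => (x - aMean) * (y - bMean) }.sum
  println(ck / a.size)        // population covariance: 565.25
  println(ck / (a.size - 1))  // sample covariance: 595.0
}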
