Skip to content

Commit 5a419f6

Browse files
committed
Add covar_pop and covar_samp.
1 parent c793d2d commit 5a419f6

File tree

4 files changed

+271
-0
lines changed

4 files changed

+271
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ object FunctionRegistry {
182182
expression[Average]("avg"),
183183
expression[Corr]("corr"),
184184
expression[Count]("count"),
185+
expression[CovPopulation]("covar_pop"),
186+
expression[CovSample]("covar_samp"),
185187
expression[First]("first"),
186188
expression[First]("first_value"),
187189
expression[Last]("last"),
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.aggregate
19+
20+
import org.apache.spark.sql.catalyst.InternalRow
21+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
22+
import org.apache.spark.sql.catalyst.expressions._
23+
import org.apache.spark.sql.catalyst.util.TypeUtils
24+
import org.apache.spark.sql.types._
25+
26+
/**
 * Compute the covariance between two expressions.
 *
 * The aggregation buffer has four slots, laid out from the buffer offset onward:
 * the running mean of the left input (`xAvg`), the running mean of the right input (`yAvg`),
 * the running co-moment (`Ck`) and the number of rows aggregated so far (`count`).
 * Rows for which either input evaluates to null are ignored.
 *
 * When applied on empty data (i.e., count is zero), the concrete subclasses return NULL.
 */
abstract class Covariance(
    left: Expression,
    right: Expression,
    mutableAggBufferOffset: Int,
    inputAggBufferOffset: Int)
  extends ImperativeAggregate with Serializable {

  override def children: Seq[Expression] = Seq(left, right)

  // `eval` in CovSample and CovPopulation returns null when no rows were aggregated (and when
  // the result is NaN), so this expression has to be declared nullable.
  override def nullable: Boolean = true

  override def dataType: DataType = DoubleType

  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)

  /** Accepts the inputs only when both child expressions are of double type. */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(
        s"covariance requires that both arguments are double type, " +
          s"not (${left.dataType}, ${right.dataType}).")
    }
  }

  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

  override def inputAggBufferAttributes: Seq[AttributeReference] = {
    aggBufferAttributes.map(_.newInstance())
  }

  // Buffer layout: [xAvg, yAvg, Ck, count].
  override val aggBufferAttributes: Seq[AttributeReference] = Seq(
    AttributeReference("xAvg", DoubleType)(),
    AttributeReference("yAvg", DoubleType)(),
    AttributeReference("Ck", DoubleType)(),
    AttributeReference("count", LongType)())

  // Local cache of mutableAggBufferOffset(s) that will be used in update and merge
  val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1
  val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2
  val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3

  // Local cache of inputAggBufferOffset(s) that will be used in update and merge
  val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1
  val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2
  val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3

  /** Resets the four buffer slots to zero. */
  override def initialize(buffer: MutableRow): Unit = {
    buffer.setDouble(mutableAggBufferOffset, 0.0)
    buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0)
    buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0)
    buffer.setLong(mutableAggBufferOffsetPlus3, 0L)
  }

  /**
   * Folds one input row into the buffer. The row is skipped when either expression
   * evaluates to null.
   */
  override def update(buffer: MutableRow, input: InternalRow): Unit = {
    val leftEval = left.eval(input)
    val rightEval = right.eval(input)

    if (leftEval != null && rightEval != null) {
      val x = leftEval.asInstanceOf[Double]
      val y = rightEval.asInstanceOf[Double]

      var xAvg = buffer.getDouble(mutableAggBufferOffset)
      var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1)
      var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
      var count = buffer.getLong(mutableAggBufferOffsetPlus3)

      // Deltas are taken against the means *before* this row is included.
      val deltaX = x - xAvg
      val deltaY = y - yAvg
      count += 1
      xAvg += deltaX / count
      yAvg += deltaY / count
      // Uses the old-mean delta for x and the already-updated mean for y.
      Ck += deltaX * (y - yAvg)

      buffer.setDouble(mutableAggBufferOffset, xAvg)
      buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg)
      buffer.setDouble(mutableAggBufferOffsetPlus2, Ck)
      buffer.setLong(mutableAggBufferOffsetPlus3, count)
    }
  }

  // Merge counters from other partitions. Formula can be found at:
  // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
    val count2 = buffer2.getLong(inputAggBufferOffsetPlus3)

    // We only go to merge two buffers if there is at least one record aggregated in buffer2.
    // We don't need to check count in buffer1 because if count2 is more than zero, totalCount
    // is more than zero too, then we won't get a divide by zero exception.
    if (count2 > 0) {
      var xAvg = buffer1.getDouble(mutableAggBufferOffset)
      var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1)
      var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2)
      var count = buffer1.getLong(mutableAggBufferOffsetPlus3)

      val xAvg2 = buffer2.getDouble(inputAggBufferOffset)
      val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1)
      val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2)

      val totalCount = count + count2
      val deltaX = xAvg - xAvg2
      val deltaY = yAvg - yAvg2
      // Combined co-moment: both partial co-moments plus a correction for the mean difference.
      Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
      // Combined means are the count-weighted averages of the two partial means.
      xAvg = (xAvg * count + xAvg2 * count2) / totalCount
      yAvg = (yAvg * count + yAvg2 * count2) / totalCount
      count = totalCount

      buffer1.setDouble(mutableAggBufferOffset, xAvg)
      buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg)
      buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck)
      buffer1.setLong(mutableAggBufferOffsetPlus3, count)
    }
  }
}
145+
146+
/**
 * Sample covariance of two expressions: the accumulated co-moment divided by `count - 1`.
 *
 * Evaluates to null when no rows were aggregated or when the quotient is NaN, and to 0.0
 * when exactly one row was aggregated.
 */
case class CovSample(
    left: Expression,
    right: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends Covariance(left, right, mutableAggBufferOffset, inputAggBufferOffset) {

  /** Two-argument constructor; both buffer offsets start at zero. */
  def this(left: Expression, right: Expression) = this(left, right, 0, 0)

  override def withNewMutableAggBufferOffset(
      newMutableAggBufferOffset: Int): ImperativeAggregate = {
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
  }

  override def withNewInputAggBufferOffset(
      newInputAggBufferOffset: Int): ImperativeAggregate = {
    copy(inputAggBufferOffset = newInputAggBufferOffset)
  }

  override def eval(buffer: InternalRow): Any = {
    val rows = buffer.getLong(mutableAggBufferOffsetPlus3)
    if (rows <= 0) {
      // Nothing was aggregated.
      null
    } else if (rows == 1) {
      // A single row yields zero rather than a division by zero.
      0.0
    } else {
      val sampleCov = buffer.getDouble(mutableAggBufferOffsetPlus2) / (rows - 1)
      if (sampleCov.isNaN) null else sampleCov
    }
  }
}
181+
182+
/**
 * Population covariance of two expressions: the accumulated co-moment divided by `count`.
 *
 * Evaluates to null when no rows were aggregated or when the quotient is NaN.
 */
case class CovPopulation(
    left: Expression,
    right: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends Covariance(left, right, mutableAggBufferOffset, inputAggBufferOffset) {

  /** Two-argument constructor; both buffer offsets start at zero. */
  def this(left: Expression, right: Expression) = this(left, right, 0, 0)

  override def withNewMutableAggBufferOffset(
      newMutableAggBufferOffset: Int): ImperativeAggregate = {
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
  }

  override def withNewInputAggBufferOffset(
      newInputAggBufferOffset: Int): ImperativeAggregate = {
    copy(inputAggBufferOffset = newInputAggBufferOffset)
  }

  override def eval(buffer: InternalRow): Any = {
    val rows = buffer.getLong(mutableAggBufferOffsetPlus3)
    if (rows <= 0) {
      // Nothing was aggregated.
      null
    } else {
      val populationCov = buffer.getDouble(mutableAggBufferOffsetPlus2) / rows
      if (populationCov.isNaN) null else populationCov
    }
  }
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,46 @@ object functions extends LegacyFunctions {
308308
def countDistinct(columnName: String, columnNames: String*): Column =
309309
countDistinct(Column(columnName), columnNames.map(Column.apply) : _*)
310310

311+
/**
312+
* Aggregate function: returns the population covariance for two columns.
313+
*
314+
* @group agg_funcs
315+
* @since 1.6.0
316+
*/
317+
def covar_pop(column1: Column, column2: Column): Column = {
  // Build the aggregate expression from the two column expressions, then wrap it.
  val populationCovariance = CovPopulation(column1.expr, column2.expr)
  withAggregateFunction(populationCovariance)
}
320+
321+
/**
322+
* Aggregate function: returns the population covariance for two columns.
323+
*
324+
* @group agg_funcs
325+
* @since 1.6.0
326+
*/
327+
def covar_pop(columnName1: String, columnName2: String): Column = {
  // Resolve the names to columns and delegate to the Column-based overload.
  val first = Column(columnName1)
  val second = Column(columnName2)
  covar_pop(first, second)
}
330+
331+
/**
332+
* Aggregate function: returns the sample covariance for two columns.
333+
*
334+
* @group agg_funcs
335+
* @since 1.6.0
336+
*/
337+
def covar_samp(column1: Column, column2: Column): Column = {
  // Build the aggregate expression from the two column expressions, then wrap it.
  val sampleCovariance = CovSample(column1.expr, column2.expr)
  withAggregateFunction(sampleCovariance)
}
340+
341+
/**
342+
* Aggregate function: returns the sample covariance for two columns.
343+
*
344+
* @group agg_funcs
345+
* @since 1.6.0
346+
*/
347+
def covar_samp(columnName1: String, columnName2: String): Column = {
  // Resolve the names to columns and delegate to the Column-based overload.
  val first = Column(columnName1)
  val second = Column(columnName2)
  covar_samp(first, second)
}
350+
311351
/**
312352
* Aggregate function: returns the first value in a group.
313353
*

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
788788
assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
789789
}
790790

791+
test("covariance: covar_pop and covar_samp") {
  // Expected values come from numpy on the same 20 points:
  // >>> import numpy as np
  // >>> a = np.array(range(20))
  // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
  // >>> np.cov(a, b, bias = 0)[0][1]
  // 595.0
  // >>> np.cov(a, b, bias = 1)[0][1]
  // 565.25
  val tolerance = 1e-12
  val points = Seq.tabulate(20) { i => (1.0 * i, i * i - 2 * i + 3.5) }
  val input = points.toDF("a", "b")

  // Sample covariance over the whole data set (no grouping columns).
  val sampleRows = input.groupBy().agg(covar_samp("a", "b")).collect()
  val sampleCov = sampleRows(0).getDouble(0)
  assert(math.abs(sampleCov - 595.0) < tolerance)

  // Population covariance over the same data.
  val populationRows = input.groupBy().agg(covar_pop("a", "b")).collect()
  val populationCov = populationRows(0).getDouble(0)
  assert(math.abs(populationCov - 565.25) < tolerance)
}
807+
791808
test("no aggregation function (SPARK-11486)") {
792809
val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s")
793810
.groupBy("s").count()

0 commit comments

Comments
 (0)