Skip to content

Commit 45d9f0c

Browse files
committed
For comments.
1 parent d094286 commit 45d9f0c

File tree

3 files changed

+51
-51
lines changed

3 files changed

+51
-51
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala

Lines changed: 40 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,11 @@ import org.apache.spark.sql.types._
3030
*/
3131
abstract class Covariance(
3232
left: Expression,
33-
right: Expression,
34-
mutableAggBufferOffset: Int,
35-
inputAggBufferOffset: Int)
33+
right: Expression)
3634
extends ImperativeAggregate with Serializable {
37-
3835
override def children: Seq[Expression] = Seq(left, right)
3936

40-
override def nullable: Boolean = false
37+
override def nullable: Boolean = true
4138

4239
override def dataType: DataType = DoubleType
4340

@@ -66,20 +63,22 @@ abstract class Covariance(
6663
AttributeReference("count", LongType)())
6764

6865
// Local cache of mutableAggBufferOffset(s) that will be used in update and merge
69-
val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1
70-
val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2
71-
val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3
66+
val xAvgOffset = mutableAggBufferOffset
67+
val yAvgOffset = mutableAggBufferOffset + 1
68+
val CkOffset = mutableAggBufferOffset + 2
69+
val countOffset = mutableAggBufferOffset + 3
7270

7371
// Local cache of inputAggBufferOffset(s) that will be used in update and merge
74-
val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1
75-
val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2
76-
val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3
72+
val inputXAvgOffset = inputAggBufferOffset
73+
val inputYAvgOffset = inputAggBufferOffset + 1
74+
val inputCkOffset = inputAggBufferOffset + 2
75+
val inputCountOffset = inputAggBufferOffset + 3
7776

7877
override def initialize(buffer: MutableRow): Unit = {
79-
buffer.setDouble(mutableAggBufferOffset, 0.0)
80-
buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0)
81-
buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0)
82-
buffer.setLong(mutableAggBufferOffsetPlus3, 0L)
78+
buffer.setDouble(xAvgOffset, 0.0)
79+
buffer.setDouble(yAvgOffset, 0.0)
80+
buffer.setDouble(CkOffset, 0.0)
81+
buffer.setLong(countOffset, 0L)
8382
}
8483

8584
override def update(buffer: MutableRow, input: InternalRow): Unit = {
@@ -90,10 +89,10 @@ abstract class Covariance(
9089
val x = leftEval.asInstanceOf[Double]
9190
val y = rightEval.asInstanceOf[Double]
9291

93-
var xAvg = buffer.getDouble(mutableAggBufferOffset)
94-
var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1)
95-
var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
96-
var count = buffer.getLong(mutableAggBufferOffsetPlus3)
92+
var xAvg = buffer.getDouble(xAvgOffset)
93+
var yAvg = buffer.getDouble(yAvgOffset)
94+
var Ck = buffer.getDouble(CkOffset)
95+
var count = buffer.getLong(countOffset)
9796

9897
val deltaX = x - xAvg
9998
val deltaY = y - yAvg
@@ -102,30 +101,30 @@ abstract class Covariance(
102101
yAvg += deltaY / count
103102
Ck += deltaX * (y - yAvg)
104103

105-
buffer.setDouble(mutableAggBufferOffset, xAvg)
106-
buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg)
107-
buffer.setDouble(mutableAggBufferOffsetPlus2, Ck)
108-
buffer.setLong(mutableAggBufferOffsetPlus3, count)
104+
buffer.setDouble(xAvgOffset, xAvg)
105+
buffer.setDouble(yAvgOffset, yAvg)
106+
buffer.setDouble(CkOffset, Ck)
107+
buffer.setLong(countOffset, count)
109108
}
110109
}
111110

112111
// Merge counters from other partitions. Formula can be found at:
113112
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
114113
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
115-
val count2 = buffer2.getLong(inputAggBufferOffsetPlus3)
114+
val count2 = buffer2.getLong(inputCountOffset)
116115

117116
// We only go to merge two buffers if there is at least one record aggregated in buffer2.
118117
// We don't need to check count in buffer1 because if count2 is more than zero, totalCount
119118
// is more than zero too, then we won't get a divide by zero exception.
120119
if (count2 > 0) {
121-
var xAvg = buffer1.getDouble(mutableAggBufferOffset)
122-
var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1)
123-
var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2)
124-
var count = buffer1.getLong(mutableAggBufferOffsetPlus3)
120+
var xAvg = buffer1.getDouble(xAvgOffset)
121+
var yAvg = buffer1.getDouble(yAvgOffset)
122+
var Ck = buffer1.getDouble(CkOffset)
123+
var count = buffer1.getLong(countOffset)
125124

126-
val xAvg2 = buffer2.getDouble(inputAggBufferOffset)
127-
val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1)
128-
val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2)
125+
val xAvg2 = buffer2.getDouble(inputXAvgOffset)
126+
val yAvg2 = buffer2.getDouble(inputYAvgOffset)
127+
val Ck2 = buffer2.getDouble(inputCkOffset)
129128

130129
val totalCount = count + count2
131130
val deltaX = xAvg - xAvg2
@@ -135,10 +134,10 @@ abstract class Covariance(
135134
yAvg = (yAvg * count + yAvg2 * count2) / totalCount
136135
count = totalCount
137136

138-
buffer1.setDouble(mutableAggBufferOffset, xAvg)
139-
buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg)
140-
buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck)
141-
buffer1.setLong(mutableAggBufferOffsetPlus3, count)
137+
buffer1.setDouble(xAvgOffset, xAvg)
138+
buffer1.setDouble(yAvgOffset, yAvg)
139+
buffer1.setDouble(CkOffset, Ck)
140+
buffer1.setLong(countOffset, count)
142141
}
143142
}
144143
}
@@ -148,10 +147,7 @@ case class CovSample(
148147
right: Expression,
149148
mutableAggBufferOffset: Int = 0,
150149
inputAggBufferOffset: Int = 0)
151-
extends Covariance(left, right, mutableAggBufferOffset, inputAggBufferOffset) {
152-
153-
def this(left: Expression, right: Expression) =
154-
this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
150+
extends Covariance(left, right) {
155151

156152
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
157153
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -160,10 +156,10 @@ case class CovSample(
160156
copy(inputAggBufferOffset = newInputAggBufferOffset)
161157

162158
override def eval(buffer: InternalRow): Any = {
163-
val count = buffer.getLong(mutableAggBufferOffsetPlus3)
159+
val count = buffer.getLong(countOffset)
164160
if (count > 0) {
165161
if (count > 1) {
166-
val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
162+
val Ck = buffer.getDouble(CkOffset)
167163
val cov = Ck / (count - 1)
168164
if (cov.isNaN) {
169165
null
@@ -184,10 +180,7 @@ case class CovPopulation(
184180
right: Expression,
185181
mutableAggBufferOffset: Int = 0,
186182
inputAggBufferOffset: Int = 0)
187-
extends Covariance(left, right, mutableAggBufferOffset, inputAggBufferOffset) {
188-
189-
def this(left: Expression, right: Expression) =
190-
this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
183+
extends Covariance(left, right) {
191184

192185
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
193186
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -196,9 +189,9 @@ case class CovPopulation(
196189
copy(inputAggBufferOffset = newInputAggBufferOffset)
197190

198191
override def eval(buffer: InternalRow): Any = {
199-
val count = buffer.getLong(mutableAggBufferOffsetPlus3)
192+
val count = buffer.getLong(countOffset)
200193
if (count > 0) {
201-
val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
194+
val Ck = buffer.getDouble(CkOffset)
202195
val cov = Ck / count
203196
if (cov.isNaN) {
204197
null

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ object functions extends LegacyFunctions {
312312
* Aggregate function: returns the population covariance for two columns.
313313
*
314314
* @group agg_funcs
315-
* @since 1.6.0
315+
* @since 2.0.0
316316
*/
317317
def covar_pop(column1: Column, column2: Column): Column = withAggregateFunction {
318318
CovPopulation(column1.expr, column2.expr)
@@ -322,7 +322,7 @@ object functions extends LegacyFunctions {
322322
* Aggregate function: returns the population covariance for two columns.
323323
*
324324
* @group agg_funcs
325-
* @since 1.6.0
325+
* @since 2.0.0
326326
*/
327327
def covar_pop(columnName1: String, columnName2: String): Column = {
328328
covar_pop(Column(columnName1), Column(columnName2))
@@ -332,7 +332,7 @@ object functions extends LegacyFunctions {
332332
* Aggregate function: returns the sample covariance for two columns.
333333
*
334334
* @group agg_funcs
335-
* @since 1.6.0
335+
* @since 2.0.0
336336
*/
337337
def covar_samp(column1: Column, column2: Column): Column = withAggregateFunction {
338338
CovSample(column1.expr, column2.expr)
@@ -342,7 +342,7 @@ object functions extends LegacyFunctions {
342342
* Aggregate function: returns the sample covariance for two columns.
343343
*
344344
* @group agg_funcs
345-
* @since 1.6.0
345+
* @since 2.0.0
346346
*/
347347
def covar_samp(columnName1: String, columnName2: String): Column = {
348348
covar_samp(Column(columnName1), Column(columnName2))

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,13 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
822822

823823
val cov_pop = df.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
824824
assert(math.abs(cov_pop - 565.25) < 1e-12)
825+
826+
val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
827+
val cov_samp2 = df2.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
828+
assert(math.abs(cov_samp2 - 11564.0) < 1e-12)
829+
830+
val cov_pop2 = df2.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
831+
assert(math.abs(cov_pop2 - 10985.799999999999) < 1e-12)
825832
}
826833

827834
test("no aggregation function (SPARK-11486)") {

0 commit comments

Comments
 (0)