@@ -30,14 +30,11 @@ import org.apache.spark.sql.types._
30
30
*/
31
31
abstract class Covariance (
32
32
left : Expression ,
33
- right : Expression ,
34
- mutableAggBufferOffset : Int ,
35
- inputAggBufferOffset : Int )
33
+ right : Expression )
36
34
extends ImperativeAggregate with Serializable {
37
-
38
35
override def children : Seq [Expression ] = Seq (left, right)
39
36
40
- override def nullable : Boolean = false
37
+ override def nullable : Boolean = true
41
38
42
39
override def dataType : DataType = DoubleType
43
40
@@ -66,20 +63,22 @@ abstract class Covariance(
66
63
AttributeReference (" count" , LongType )())
67
64
68
65
// Local cache of mutableAggBufferOffset(s) that will be used in update and merge
69
- val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1
70
- val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2
71
- val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3
66
+ val xAvgOffset = mutableAggBufferOffset
67
+ val yAvgOffset = mutableAggBufferOffset + 1
68
+ val CkOffset = mutableAggBufferOffset + 2
69
+ val countOffset = mutableAggBufferOffset + 3
72
70
73
71
// Local cache of inputAggBufferOffset(s) that will be used in update and merge
74
- val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1
75
- val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2
76
- val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3
72
+ val inputXAvgOffset = inputAggBufferOffset
73
+ val inputYAvgOffset = inputAggBufferOffset + 1
74
+ val inputCkOffset = inputAggBufferOffset + 2
75
+ val inputCountOffset = inputAggBufferOffset + 3
77
76
78
77
override def initialize (buffer : MutableRow ): Unit = {
79
- buffer.setDouble(mutableAggBufferOffset , 0.0 )
80
- buffer.setDouble(mutableAggBufferOffsetPlus1 , 0.0 )
81
- buffer.setDouble(mutableAggBufferOffsetPlus2 , 0.0 )
82
- buffer.setLong(mutableAggBufferOffsetPlus3 , 0L )
78
+ buffer.setDouble(xAvgOffset , 0.0 )
79
+ buffer.setDouble(yAvgOffset , 0.0 )
80
+ buffer.setDouble(CkOffset , 0.0 )
81
+ buffer.setLong(countOffset , 0L )
83
82
}
84
83
85
84
override def update (buffer : MutableRow , input : InternalRow ): Unit = {
@@ -90,10 +89,10 @@ abstract class Covariance(
90
89
val x = leftEval.asInstanceOf [Double ]
91
90
val y = rightEval.asInstanceOf [Double ]
92
91
93
- var xAvg = buffer.getDouble(mutableAggBufferOffset )
94
- var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1 )
95
- var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2 )
96
- var count = buffer.getLong(mutableAggBufferOffsetPlus3 )
92
+ var xAvg = buffer.getDouble(xAvgOffset )
93
+ var yAvg = buffer.getDouble(yAvgOffset )
94
+ var Ck = buffer.getDouble(CkOffset )
95
+ var count = buffer.getLong(countOffset )
97
96
98
97
val deltaX = x - xAvg
99
98
val deltaY = y - yAvg
@@ -102,30 +101,30 @@ abstract class Covariance(
102
101
yAvg += deltaY / count
103
102
Ck += deltaX * (y - yAvg)
104
103
105
- buffer.setDouble(mutableAggBufferOffset , xAvg)
106
- buffer.setDouble(mutableAggBufferOffsetPlus1 , yAvg)
107
- buffer.setDouble(mutableAggBufferOffsetPlus2 , Ck )
108
- buffer.setLong(mutableAggBufferOffsetPlus3 , count)
104
+ buffer.setDouble(xAvgOffset , xAvg)
105
+ buffer.setDouble(yAvgOffset , yAvg)
106
+ buffer.setDouble(CkOffset , Ck )
107
+ buffer.setLong(countOffset , count)
109
108
}
110
109
}
111
110
112
111
// Merge counters from other partitions. Formula can be found at:
113
112
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
114
113
override def merge (buffer1 : MutableRow , buffer2 : InternalRow ): Unit = {
115
- val count2 = buffer2.getLong(inputAggBufferOffsetPlus3 )
114
+ val count2 = buffer2.getLong(inputCountOffset )
116
115
117
116
// We only go to merge two buffers if there is at least one record aggregated in buffer2.
118
117
// We don't need to check count in buffer1 because if count2 is more than zero, totalCount
119
118
// is more than zero too, then we won't get a divide by zero exception.
120
119
if (count2 > 0 ) {
121
- var xAvg = buffer1.getDouble(mutableAggBufferOffset )
122
- var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1 )
123
- var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2 )
124
- var count = buffer1.getLong(mutableAggBufferOffsetPlus3 )
120
+ var xAvg = buffer1.getDouble(xAvgOffset )
121
+ var yAvg = buffer1.getDouble(yAvgOffset )
122
+ var Ck = buffer1.getDouble(CkOffset )
123
+ var count = buffer1.getLong(countOffset )
125
124
126
- val xAvg2 = buffer2.getDouble(inputAggBufferOffset )
127
- val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1 )
128
- val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2 )
125
+ val xAvg2 = buffer2.getDouble(inputXAvgOffset )
126
+ val yAvg2 = buffer2.getDouble(inputYAvgOffset )
127
+ val Ck2 = buffer2.getDouble(inputCkOffset )
129
128
130
129
val totalCount = count + count2
131
130
val deltaX = xAvg - xAvg2
@@ -135,10 +134,10 @@ abstract class Covariance(
135
134
yAvg = (yAvg * count + yAvg2 * count2) / totalCount
136
135
count = totalCount
137
136
138
- buffer1.setDouble(mutableAggBufferOffset , xAvg)
139
- buffer1.setDouble(mutableAggBufferOffsetPlus1 , yAvg)
140
- buffer1.setDouble(mutableAggBufferOffsetPlus2 , Ck )
141
- buffer1.setLong(mutableAggBufferOffsetPlus3 , count)
137
+ buffer1.setDouble(xAvgOffset , xAvg)
138
+ buffer1.setDouble(yAvgOffset , yAvg)
139
+ buffer1.setDouble(CkOffset , Ck )
140
+ buffer1.setLong(countOffset , count)
142
141
}
143
142
}
144
143
}
@@ -148,10 +147,7 @@ case class CovSample(
148
147
right : Expression ,
149
148
mutableAggBufferOffset : Int = 0 ,
150
149
inputAggBufferOffset : Int = 0 )
151
- extends Covariance (left, right, mutableAggBufferOffset, inputAggBufferOffset) {
152
-
153
- def this (left : Expression , right : Expression ) =
154
- this (left, right, mutableAggBufferOffset = 0 , inputAggBufferOffset = 0 )
150
+ extends Covariance (left, right) {
155
151
156
152
override def withNewMutableAggBufferOffset (newMutableAggBufferOffset : Int ): ImperativeAggregate =
157
153
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -160,10 +156,10 @@ case class CovSample(
160
156
copy(inputAggBufferOffset = newInputAggBufferOffset)
161
157
162
158
override def eval (buffer : InternalRow ): Any = {
163
- val count = buffer.getLong(mutableAggBufferOffsetPlus3 )
159
+ val count = buffer.getLong(countOffset )
164
160
if (count > 0 ) {
165
161
if (count > 1 ) {
166
- val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2 )
162
+ val Ck = buffer.getDouble(CkOffset )
167
163
val cov = Ck / (count - 1 )
168
164
if (cov.isNaN) {
169
165
null
@@ -184,10 +180,7 @@ case class CovPopulation(
184
180
right : Expression ,
185
181
mutableAggBufferOffset : Int = 0 ,
186
182
inputAggBufferOffset : Int = 0 )
187
- extends Covariance (left, right, mutableAggBufferOffset, inputAggBufferOffset) {
188
-
189
- def this (left : Expression , right : Expression ) =
190
- this (left, right, mutableAggBufferOffset = 0 , inputAggBufferOffset = 0 )
183
+ extends Covariance (left, right) {
191
184
192
185
override def withNewMutableAggBufferOffset (newMutableAggBufferOffset : Int ): ImperativeAggregate =
193
186
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -196,9 +189,9 @@ case class CovPopulation(
196
189
copy(inputAggBufferOffset = newInputAggBufferOffset)
197
190
198
191
override def eval (buffer : InternalRow ): Any = {
199
- val count = buffer.getLong(mutableAggBufferOffsetPlus3 )
192
+ val count = buffer.getLong(countOffset )
200
193
if (count > 0 ) {
201
- val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2 )
194
+ val Ck = buffer.getDouble(CkOffset )
202
195
val cov = Ck / count
203
196
if (cov.isNaN) {
204
197
null
0 commit comments