
Commit 1fca1eb

feat(sql): weighted average and standard deviation (#6457)
1 parent 360948a commit 1fca1eb

25 files changed: +5129 −3056 lines

core/src/main/java/io/questdb/griffin/ExpressionParser.java

Lines changed: 5 additions & 6 deletions
@@ -888,14 +888,13 @@ void parseExpr(
 
                 lexer.backTo(lastPos + SqlKeywords.CAST_KEYWORD_LENGTH, castTok);
                 tok = castTok;
-                if (prevBranch != BRANCH_DOT_DEREFERENCE) {
-                    scopeStack.push(Scope.CAST);
-                    thisBranch = BRANCH_OPERATOR;
-                    opStack.push(expressionNodePool.next().of(ExpressionNode.LITERAL, "cast", Integer.MIN_VALUE, lastPos));
-                    break;
-                } else {
+                if (prevBranch == BRANCH_DOT_DEREFERENCE || isCompletedOperand(prevBranch)) {
                     throw SqlException.$(lastPos, "'cast' is not allowed here");
                 }
+                scopeStack.push(Scope.CAST);
+                thisBranch = BRANCH_OPERATOR;
+                opStack.push(expressionNodePool.next().of(ExpressionNode.LITERAL, "cast", Integer.MIN_VALUE, lastPos));
+                break;
                 }
                 processDefaultBranch = true;
                 break;
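
The hunk above does two things: it flattens the happy path out of the else branch, and it widens the error case so that `cast` is rejected not only after a dot-dereference but after any token sequence that already completed an operand. A stripped-down sketch of the same guard-clause inversion; all names here are illustrative stand-ins for the parser internals, not the real ExpressionParser API:

// Illustrative only -- BRANCH_* and the helpers below are hypothetical
// stand-ins for the parser fields shown in the hunk above.
final class CastGuardSketch {
    static final int BRANCH_DOT_DEREFERENCE = 1;

    // Before: the error sat in the else branch, the happy path was nested.
    void parseCastOld(int prevBranch) {
        if (prevBranch != BRANCH_DOT_DEREFERENCE) {
            pushCastOperator();
        } else {
            throw new IllegalStateException("'cast' is not allowed here");
        }
    }

    // After: reject early (now also after any completed operand),
    // then run the happy path un-nested.
    void parseCastNew(int prevBranch) {
        if (prevBranch == BRANCH_DOT_DEREFERENCE || isCompletedOperand(prevBranch)) {
            throw new IllegalStateException("'cast' is not allowed here");
        }
        pushCastOperator();
    }

    boolean isCompletedOperand(int prevBranch) {
        return false; // placeholder; the real check lives in ExpressionParser
    }

    void pushCastOperator() {
        // stands in for scopeStack.push(...), thisBranch = ..., opStack.push(...)
    }
}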

core/src/main/java/io/questdb/griffin/engine/functions/GroupByFunction.java

Lines changed: 1 addition & 0 deletions
@@ -190,6 +190,7 @@ default void setDouble(MapValue mapValue, double value) {
         throw new UnsupportedOperationException();
     }
 
+    // used by generated code
     default void setEmpty(MapValue value) {
         setNull(value);
     }

core/src/main/java/io/questdb/griffin/engine/functions/groupby/AbstractStdDevGroupByFunction.java

Lines changed: 28 additions & 2 deletions
@@ -97,6 +97,33 @@ public boolean isConstant() {
         return false;
     }
 
+    // Chan et al. [CGL82; CGL83]
+    @Override
+    public void merge(MapValue destValue, MapValue srcValue) {
+        double srcMean = srcValue.getDouble(valueIndex);
+        double srcSum = srcValue.getDouble(valueIndex + 1);
+        long srcCount = srcValue.getLong(valueIndex + 2);
+
+        double destMean = destValue.getDouble(valueIndex);
+        double destSum = destValue.getDouble(valueIndex + 1);
+        long destCount = destValue.getLong(valueIndex + 2);
+
+        long mergedCount = srcCount + destCount;
+        double delta = destMean - srcMean;
+
+        // This is only accurate when srcCount is much larger than destCount.
+        // If both are large and of similar size, delta is not scaled down.
+        // double mergedMean = srcMean + delta * ((double) destCount / mergedCount);
+
+        // So we use this instead:
+        double mergedMean = (srcCount * srcMean + destCount * destMean) / mergedCount;
+        double mergedSum = srcSum + destSum + (delta * delta) * ((double) (srcCount * destCount) / mergedCount);
+
+        destValue.putDouble(valueIndex, mergedMean);
+        destValue.putDouble(valueIndex + 1, mergedSum);
+        destValue.putLong(valueIndex + 2, mergedCount);
+    }
+
     @Override
     public void setDouble(MapValue mapValue, double value) {
         mapValue.putDouble(valueIndex, value);
@@ -112,7 +139,7 @@ public void setNull(MapValue mapValue) {
 
     @Override
     public boolean supportsParallelism() {
-        return false;
+        return UnaryFunction.super.supportsParallelism();
     }
 
     protected void aggregate(MapValue mapValue, double value) {
@@ -128,4 +155,3 @@ protected void aggregate(MapValue mapValue, double value) {
         mapValue.addLong(valueIndex + 2, 1L);
     }
 }
-
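
As a sanity check on the merge hunk above (the same code that the StdDevPopGroupByFunctionFactory section below shows being deleted, since it has been hoisted into this base class), here is a standalone sketch, not QuestDB code, confirming that the Chan et al. combination of per-partition (mean, M2, count) states agrees with a single-pass Welford computation over the concatenated data:

// Standalone verification sketch for the Chan et al. merge formula.
public final class ChanMergeCheck {
    // Welford accumulator: returns {mean, M2 (sum of squared deviations), count}.
    static double[] welford(double[] xs) {
        double mean = 0, m2 = 0;
        long n = 0;
        for (double x : xs) {
            n++;
            double d = x - mean;
            mean += d / n;
            m2 += d * (x - mean);
        }
        return new double[]{mean, m2, n};
    }

    // Chan et al. merge, as in the diff: weighted mean plus a delta^2 correction.
    static double[] merge(double[] a, double[] b) {
        long na = (long) a[2], nb = (long) b[2], n = na + nb;
        double delta = b[0] - a[0];
        double mean = (na * a[0] + nb * b[0]) / n;
        double m2 = a[1] + b[1] + delta * delta * ((double) (na * nb) / n);
        return new double[]{mean, m2, n};
    }

    public static void main(String[] args) {
        double[] merged = merge(welford(new double[]{1, 2, 3, 4}),
                welford(new double[]{10, 20, 30}));
        double[] direct = welford(new double[]{1, 2, 3, 4, 10, 20, 30});
        // Population variance = M2 / count; both paths should agree closely.
        System.out.printf("merged var = %.12f%n", merged[1] / merged[2]);
        System.out.printf("direct var = %.12f%n", direct[1] / direct[2]);
    }
}

The weighted form of the merged mean avoids the precision loss that the commented-out incremental form suffers when the two partitions are of similar size.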
core/src/main/java/io/questdb/griffin/engine/functions/groupby/AbstractWeightedStdDevGroupByFunction.java

Lines changed: 258 additions & 0 deletions

@@ -0,0 +1,258 @@
+/*******************************************************************************
+ *     ___                  _   ____  ____
+ *    / _ \ _   _  ___  ___| |_|  _ \| __ )
+ *   | | | | | | |/ _ \/ __| __| | | |  _ \
+ *   | |_| | |_| |  __/\__ \ |_| |_| | |_) |
+ *    \__\_\\__,_|\___||___/\__|____/|____/
+ *
+ *  Copyright (c) 2014-2019 Appsicle
+ *  Copyright (c) 2019-2024 QuestDB
+ *
+ *  Licensed under the Apache License, Version 2.0 (the "License");
+ *  you may not use this file except in compliance with the License.
+ *  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License.
+ *
+ ******************************************************************************/
+
+package io.questdb.griffin.engine.functions.groupby;
+
+import io.questdb.cairo.ArrayColumnTypes;
+import io.questdb.cairo.ColumnType;
+import io.questdb.cairo.map.MapValue;
+import io.questdb.cairo.sql.Function;
+import io.questdb.cairo.sql.Record;
+import io.questdb.griffin.engine.functions.BinaryFunction;
+import io.questdb.griffin.engine.functions.DoubleFunction;
+import io.questdb.griffin.engine.functions.GroupByFunction;
+import io.questdb.std.Numbers;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Base class to compute the unbiased weighted standard deviation.
+ * <p>
+ * According to
+ * <a href="https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance">
+ * Wikipedia</a>, there are two variants of the unbiased estimator of population variance
+ * based on a subset of it (a sample): one for frequency weights and another for
+ * reliability weights.
+ * <p>
+ * A frequency weight represents the number of occurrences of the sample in the dataset.
+ * A reliability weight represents the "importance" or "trustworthiness" of the given sample.
+ * <p>
+ * We implement both functions, called <code>weighted_stddev_rel</code> for reliability weights
+ * and <code>weighted_stddev_freq</code> for frequency weights. We also define the shorthand
+ * <code>weighted_stddev</code> for <code>weighted_stddev_rel</code>.
+ * <p>
+ * These two:
+ * <pre>
+ * SELECT weighted_stddev_rel(value, weight) FROM my_table;
+ * SELECT weighted_stddev(value, weight) FROM my_table;
+ * </pre>
+ * calculate the equivalent of
+ * <pre>
+ * SELECT sqrt(
+ *     (
+ *         sum(weight * value * value)
+ *         - (sum(weight * value) * sum(weight * value) / sum(weight))
+ *     )
+ *     / (sum(weight) - sum(weight * weight) / sum(weight))
+ * ) FROM my_table;
+ * </pre>
+ * <p>
+ * This one:
+ * <pre>
+ * SELECT weighted_stddev_freq(value, weight) FROM my_table;
+ * </pre>
+ * calculates the equivalent of
+ * <pre>
+ * SELECT sqrt(
+ *     (
+ *         sum(weight * value * value)
+ *         - (sum(weight * value) * sum(weight * value) / sum(weight))
+ *     )
+ *     / (sum(weight) - 1)
+ * ) FROM my_table;
+ * </pre>
+ *
+ * @see <a href="https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm">
+ * Weighted incremental algorithm</a>
+ * @see <a href="https://github.com/mraad/spark-stat/blob/master/src/main/scala/com/esri/spark/WeightedStatCounter.scala">
+ * Merge function for weighted incremental algorithm</a>
+ */
+public abstract class AbstractWeightedStdDevGroupByFunction extends DoubleFunction implements GroupByFunction, BinaryFunction {
+    protected final Function sampleArg;
+    private final Function weightArg;
+    protected int valueIndex;
+
+    protected AbstractWeightedStdDevGroupByFunction(@NotNull Function sampleArg, @NotNull Function weightArg) {
+        this.sampleArg = sampleArg;
+        this.weightArg = weightArg;
+    }
+
+    @Override
+    public void computeFirst(MapValue mapValue, Record record, long rowId) {
+        final double sample = sampleArg.getDouble(record);
+        final double weight = weightArg.getDouble(record);
+        if (!Numbers.isFinite(sample) || !Numbers.isFinite(weight) || weight == 0.0) {
+            mapValue.putDouble(valueIndex, 0.0); // w_sum
+            mapValue.putDouble(valueIndex + 1, 0.0); // w_sum2
+            mapValue.putDouble(valueIndex + 2, 0.0); // mean
+            mapValue.putDouble(valueIndex + 3, 0.0); // S
+            return;
+        }
+        mapValue.putDouble(valueIndex, weight); // w_sum
+        mapValue.putDouble(valueIndex + 1, weight * weight); // w_sum2
+        mapValue.putDouble(valueIndex + 2, sample); // mean
+        mapValue.putDouble(valueIndex + 3, 0.0); // S
+    }
+
+    @Override
+    public void computeNext(MapValue mapValue, Record record, long rowId) {
+        // Acquire values from record
+        final double sample = sampleArg.getDouble(record);
+        final double weight = weightArg.getDouble(record);
+        if (!Numbers.isFinite(sample) || !Numbers.isFinite(weight) || weight == 0.0) {
+            return;
+        }
+        // Acquire current computation state
+        double wSum = mapValue.getDouble(valueIndex);
+        double wSum2 = mapValue.getDouble(valueIndex + 1);
+        double mean = mapValue.getDouble(valueIndex + 2);
+        double s = mapValue.getDouble(valueIndex + 3);
+
+        // Update computation state with values from record
+        wSum += weight;
+        wSum2 += weight * weight;
+        double meanOld = mean;
+        mean += (weight / wSum) * (sample - meanOld);
+        s += weight * (sample - meanOld) * (sample - mean);
+
+        // Store updated computation state
+        mapValue.putDouble(valueIndex, wSum);
+        mapValue.putDouble(valueIndex + 1, wSum2);
+        mapValue.putDouble(valueIndex + 2, mean);
+        mapValue.putDouble(valueIndex + 3, s);
+    }
+
+    @Override
+    public Function getLeft() {
+        return sampleArg;
+    }
+
+    @Override
+    public Function getRight() {
+        return weightArg;
+    }
+
+    @Override
+    public int getSampleByFlags() {
+        return GroupByFunction.SAMPLE_BY_FILL_ALL;
+    }
+
+    @Override
+    public int getValueIndex() {
+        return valueIndex;
+    }
+
+    @Override
+    public void initValueIndex(int valueIndex) {
+        this.valueIndex = valueIndex;
+    }
+
+    @Override
+    public void initValueTypes(ArrayColumnTypes columnTypes) {
+        this.valueIndex = columnTypes.getColumnCount();
+        columnTypes.add(ColumnType.DOUBLE);
+        columnTypes.add(ColumnType.DOUBLE);
+        columnTypes.add(ColumnType.DOUBLE);
+        columnTypes.add(ColumnType.DOUBLE);
+    }
+
+    @Override
+    public boolean isConstant() {
+        return false;
+    }
+
+    @Override
+    public boolean isThreadSafe() {
+        return BinaryFunction.super.isThreadSafe();
+    }
+
+    @Override
+    public void merge(MapValue destValue, MapValue srcValue) {
+        // Acquire source computation state
+        double srcWsum = srcValue.getDouble(valueIndex);
+        double srcWsum2 = srcValue.getDouble(valueIndex + 1);
+        double srcMean = srcValue.getDouble(valueIndex + 2);
+        double srcS = srcValue.getDouble(valueIndex + 3);
+
+        if (srcWsum == 0.0) {
+            // srcValue has no data -- return with destValue untouched
+            return;
+        }
+
+        // Acquire destination computation state
+        double destWsum = destValue.getDouble(valueIndex);
+        double destWsum2 = destValue.getDouble(valueIndex + 1);
+        double destMean = destValue.getDouble(valueIndex + 2);
+        double destS = destValue.getDouble(valueIndex + 3);
+
+        if (destWsum == 0.0) {
+            // srcValue has data, destValue doesn't. Copy entire srcValue to destValue.
+            destValue.putDouble(valueIndex, srcWsum);
+            destValue.putDouble(valueIndex + 1, srcWsum2);
+            destValue.putDouble(valueIndex + 2, srcMean);
+            destValue.putDouble(valueIndex + 3, srcS);
+            return;
+        }
+        // Both srcValue and destValue have data -- merge them
+
+        // Compute interim results
+        double meanDelta = srcMean - destMean;
+
+        // Compute merged computation state
+        double mergedWsum = srcWsum + destWsum;
+        double mergedWsum2 = srcWsum2 + destWsum2;
+        double mergedMean = (srcWsum * srcMean + destWsum * destMean) / mergedWsum;
+        double mergedS = srcS + destS + (srcWsum * meanDelta) / mergedWsum * (destWsum * meanDelta);
+
+        // Store merged computation state to destination
+        destValue.putDouble(valueIndex, mergedWsum);
+        destValue.putDouble(valueIndex + 1, mergedWsum2);
+        destValue.putDouble(valueIndex + 2, mergedMean);
+        destValue.putDouble(valueIndex + 3, mergedS);
+    }
+
+    @Override
+    public void setDouble(MapValue mapValue, double value) {
+        // We set the state such that getDouble() in both the Frequency and
+        // Reliability subclasses ends up returning `value`
+        mapValue.putDouble(valueIndex, 2.0); // wSum
+        mapValue.putDouble(valueIndex + 1, 2.0); // wSum2
+        mapValue.putDouble(valueIndex + 2, Double.NaN); // mean
+        mapValue.putDouble(valueIndex + 3, value * value); // S
+    }
+
+    @Override
+    public void setNull(MapValue mapValue) {
+        mapValue.putDouble(valueIndex, Double.NaN);
+        mapValue.putDouble(valueIndex + 1, Double.NaN);
+        mapValue.putDouble(valueIndex + 2, Double.NaN);
+        mapValue.putDouble(valueIndex + 3, Double.NaN);
+    }
+
+    @Override
+    public boolean supportsParallelism() {
+        return BinaryFunction.super.supportsParallelism();
+    }
+}
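
To make the state machine above easier to follow, here is a minimal, self-contained sketch (again not QuestDB code) of the weighted incremental update from computeNext(), together with the two final estimators. The Frequency and Reliability subclasses are not part of this excerpt, so the denominators below are taken from the SQL equivalents in the Javadoc rather than from their getDouble() implementations:

// Standalone sketch of the weighted incremental (West 1979) algorithm.
public final class WeightedStdDevSketch {
    private double wSum, wSum2, mean, s;

    void add(double sample, double weight) {
        // Mirrors computeNext(): running weighted mean plus weighted
        // sum of squared deviations (S).
        wSum += weight;
        wSum2 += weight * weight;
        double meanOld = mean;
        mean += (weight / wSum) * (sample - meanOld);
        s += weight * (sample - meanOld) * (sample - mean);
    }

    // Reliability weights: unbiased denominator is wSum - wSum2 / wSum.
    double stdDevRel() {
        return Math.sqrt(s / (wSum - wSum2 / wSum));
    }

    // Frequency weights: unbiased denominator is wSum - 1.
    double stdDevFreq() {
        return Math.sqrt(s / (wSum - 1));
    }

    public static void main(String[] args) {
        WeightedStdDevSketch acc = new WeightedStdDevSketch();
        acc.add(10, 1);
        acc.add(20, 3);
        acc.add(30, 2);
        System.out.println("rel  = " + acc.stdDevRel());
        System.out.println("freq = " + acc.stdDevFreq());
    }
}

Note that both estimators divide the same S; only the bias correction in the denominator differs, which is why the class keeps wSum, wSum2, mean, and S as its four map columns and leaves the final division to the subclasses.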

core/src/main/java/io/questdb/griffin/engine/functions/groupby/StdDevPopGroupByFunctionFactory.java

Lines changed: 0 additions & 33 deletions
@@ -25,7 +25,6 @@
 package io.questdb.griffin.engine.functions.groupby;
 
 import io.questdb.cairo.CairoConfiguration;
-import io.questdb.cairo.map.MapValue;
 import io.questdb.cairo.sql.Function;
 import io.questdb.cairo.sql.Record;
 import io.questdb.griffin.FunctionFactory;
@@ -77,37 +76,5 @@ public double getDouble(Record rec) {
         public String getName() {
             return "stddev_pop";
         }
-
-        // Chan et al. [CGL82; CGL83]
-        @Override
-        public void merge(MapValue destValue, MapValue srcValue) {
-            double srcMean = srcValue.getDouble(valueIndex);
-            double srcSum = srcValue.getDouble(valueIndex + 1);
-            long srcCount = srcValue.getLong(valueIndex + 2);
-
-            double destMean = destValue.getDouble(valueIndex);
-            double destSum = destValue.getDouble(valueIndex + 1);
-            long destCount = destValue.getLong(valueIndex + 2);
-
-            long mergedCount = srcCount + destCount;
-            double delta = destMean - srcMean;
-
-            // This is only accurate when srcCount is much larger than destCount.
-            // If both are large and of similar size, delta is not scaled down.
-            // double mergedMean = srcMean + delta * ((double) destCount / mergedCount);
-
-            // So we use this instead:
-            double mergedMean = (srcCount * srcMean + destCount * destMean) / mergedCount;
-            double mergedSum = srcSum + destSum + (delta * delta) * ((double) (srcCount * destCount) / mergedCount);
-
-            destValue.putDouble(valueIndex, mergedMean);
-            destValue.putDouble(valueIndex + 1, mergedSum);
-            destValue.putLong(valueIndex + 2, mergedCount);
-        }
-
-        @Override
-        public boolean supportsParallelism() {
-            return true;
-        }
     }
 }
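
The merge() and supportsParallelism() overrides deleted here are the same code added to AbstractStdDevGroupByFunction earlier in this commit: the Chan et al. merge now lives in the base class, so stddev_pop keeps its parallel merge behavior while the other stddev variants built on the same base class can share it.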
