Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Impl stddev and variance function in SQL and PPL #115

Merged
merged 12 commits into from
Jun 11, 2021
Merged
1 change: 1 addition & 0 deletions core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ dependencies {
compile group: 'org.springframework', name: 'spring-beans', version: '5.2.5.RELEASE'
compile group: 'org.apache.commons', name: 'commons-lang3', version: '3.10'
compile group: 'com.facebook.presto', name: 'presto-matching', version: '0.240'
compile group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1'
compile project(':common')

testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ public Expression visitNot(Not node, AnalysisContext context) {

@Override
public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext context) {
Optional<BuiltinFunctionName> builtinFunctionName = BuiltinFunctionName.of(node.getFuncName());
Optional<BuiltinFunctionName> builtinFunctionName =
BuiltinFunctionName.ofAggregation(node.getFuncName());
if (builtinFunctionName.isPresent()) {
Expression arg = node.getField().accept(this, context);
Aggregator aggregator = (Aggregator) repository.compile(
Expand Down
16 changes: 16 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,22 @@ public Aggregator count(Expression... expressions) {
return aggregate(BuiltinFunctionName.COUNT, expressions);
}

public Aggregator varSamp(Expression... expressions) {
return aggregate(BuiltinFunctionName.VARSAMP, expressions);
}

public Aggregator varPop(Expression... expressions) {
return aggregate(BuiltinFunctionName.VARPOP, expressions);
}

public Aggregator stddevSamp(Expression... expressions) {
return aggregate(BuiltinFunctionName.STDDEV_SAMP, expressions);
}

public Aggregator stddevPop(Expression... expressions) {
return aggregate(BuiltinFunctionName.STDDEV_POP, expressions);
}

public RankingWindowFunction rowNumber() {
return (RankingWindowFunction) repository.compile(
BuiltinFunctionName.ROW_NUMBER.getName(), Collections.emptyList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.TIME;
import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP;
import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevPopulation;
import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevSample;
import static org.opensearch.sql.expression.aggregation.VarianceAggregator.variancePopulation;
import static org.opensearch.sql.expression.aggregation.VarianceAggregator.varianceSample;

import com.google.common.collect.ImmutableMap;
import java.util.Collections;
Expand Down Expand Up @@ -68,6 +72,10 @@ public static void register(BuiltinFunctionRepository repository) {
repository.register(count());
repository.register(min());
repository.register(max());
repository.register(varSamp());
repository.register(varPop());
repository.register(stddevSamp());
repository.register(stddevPop());
}

private static FunctionResolver avg() {
Expand Down Expand Up @@ -159,4 +167,48 @@ private static FunctionResolver max() {
.build()
);
}

private static FunctionResolver varSamp() {
FunctionName functionName = BuiltinFunctionName.VARSAMP.getName();
return new FunctionResolver(
functionName,
new ImmutableMap.Builder<FunctionSignature, FunctionBuilder>()
.put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)),
arguments -> varianceSample(arguments, DOUBLE))
.build()
);
}

private static FunctionResolver varPop() {
FunctionName functionName = BuiltinFunctionName.VARPOP.getName();
return new FunctionResolver(
functionName,
new ImmutableMap.Builder<FunctionSignature, FunctionBuilder>()
.put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)),
arguments -> variancePopulation(arguments, DOUBLE))
.build()
);
}

private static FunctionResolver stddevSamp() {
FunctionName functionName = BuiltinFunctionName.STDDEV_SAMP.getName();
return new FunctionResolver(
functionName,
new ImmutableMap.Builder<FunctionSignature, FunctionBuilder>()
.put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)),
arguments -> stddevSample(arguments, DOUBLE))
.build()
);
}

private static FunctionResolver stddevPop() {
FunctionName functionName = BuiltinFunctionName.STDDEV_POP.getName();
return new FunctionResolver(
functionName,
new ImmutableMap.Builder<FunctionSignature, FunctionBuilder>()
.put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)),
arguments -> stddevPopulation(arguments, DOUBLE))
.build()
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package org.opensearch.sql.expression.aggregation;

import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue;
import static org.opensearch.sql.utils.ExpressionUtils.format;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprNullValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

/**
* StandardDeviation Aggregator.
*/
public class StdDevAggregator extends Aggregator<StdDevAggregator.StdDevState> {

private final boolean isSampleStdDev;

/**
* Build Population Variance {@link VarianceAggregator}.
*/
public static Aggregator stddevPopulation(List<Expression> arguments,
ExprCoreType returnType) {
return new StdDevAggregator(false, arguments, returnType);
}

/**
* Build Sample Variance {@link VarianceAggregator}.
*/
public static Aggregator stddevSample(List<Expression> arguments,
ExprCoreType returnType) {
return new StdDevAggregator(true, arguments, returnType);
}

/**
* VarianceAggregator constructor.
*
* @param isSampleStdDev true for sample standard deviation aggregator, false for population
* standard deviation aggregator.
* @param arguments aggregator arguments.
* @param returnType aggregator return types.
*/
public StdDevAggregator(
Boolean isSampleStdDev, List<Expression> arguments, ExprCoreType returnType) {
super(
isSampleStdDev
? BuiltinFunctionName.STDDEV_SAMP.getName()
: BuiltinFunctionName.STDDEV_POP.getName(),
arguments,
returnType);
this.isSampleStdDev = isSampleStdDev;
}

@Override
public StdDevAggregator.StdDevState create() {
return new StdDevAggregator.StdDevState(isSampleStdDev);
}

@Override
protected StdDevAggregator.StdDevState iterate(ExprValue value,
StdDevAggregator.StdDevState state) {
state.evaluate(value);
return state;
}

@Override
public String toString() {
return StringUtils.format(
"%s(%s)", isSampleStdDev ? "stddev_samp" : "stddev_pop", format(getArguments()));
}

protected static class StdDevState implements AggregationState {

private final StandardDeviation standardDeviation;

private final List<Double> values = new ArrayList<>();

public StdDevState(boolean isSampleStdDev) {
this.standardDeviation = new StandardDeviation(isSampleStdDev);
}

public void evaluate(ExprValue value) {
values.add(value.doubleValue());
}

@Override
public ExprValue result() {
return values.size() == 0
? ExprNullValue.of()
: doubleValue(standardDeviation.evaluate(values.stream().mapToDouble(d -> d).toArray()));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package org.opensearch.sql.expression.aggregation;

import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue;
import static org.opensearch.sql.utils.ExpressionUtils.format;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.stat.descriptive.moment.Variance;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprNullValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

/**
* Variance Aggregator.
*/
public class VarianceAggregator extends Aggregator<VarianceAggregator.VarianceState> {

private final boolean isSampleVariance;

/**
* Build Population Variance {@link VarianceAggregator}.
*/
public static Aggregator variancePopulation(List<Expression> arguments,
ExprCoreType returnType) {
return new VarianceAggregator(false, arguments, returnType);
}

/**
* Build Sample Variance {@link VarianceAggregator}.
*/
public static Aggregator varianceSample(List<Expression> arguments,
ExprCoreType returnType) {
return new VarianceAggregator(true, arguments, returnType);
}

/**
* VarianceAggregator constructor.
*
* @param isSampleVariance true for sample variance aggregator, false for population variance
* aggregator.
* @param arguments aggregator arguments.
* @param returnType aggregator return types.
*/
public VarianceAggregator(
Boolean isSampleVariance, List<Expression> arguments, ExprCoreType returnType) {
super(
isSampleVariance
? BuiltinFunctionName.VARSAMP.getName()
: BuiltinFunctionName.VARPOP.getName(),
arguments,
returnType);
this.isSampleVariance = isSampleVariance;
}

@Override
public VarianceState create() {
return new VarianceState(isSampleVariance);
}

@Override
protected VarianceState iterate(ExprValue value, VarianceState state) {
state.evaluate(value);
return state;
}

@Override
public String toString() {
return StringUtils.format(
"%s(%s)", isSampleVariance ? "var_samp" : "var_pop", format(getArguments()));
}

protected static class VarianceState implements AggregationState {

private final Variance variance;

private final List<Double> values = new ArrayList<>();

public VarianceState(boolean isSampleVariance) {
this.variance = new Variance(isSampleVariance);
}

public void evaluate(ExprValue value) {
values.add(value.doubleValue());
}

@Override
public ExprValue result() {
return values.size() == 0
? ExprNullValue.of()
: doubleValue(variance.evaluate(values.stream().mapToDouble(d -> d).toArray()));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.opensearch.sql.expression.function;

import com.google.common.collect.ImmutableMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import lombok.Getter;
Expand Down Expand Up @@ -126,6 +127,14 @@ public enum BuiltinFunctionName {
COUNT(FunctionName.of("count")),
MIN(FunctionName.of("min")),
MAX(FunctionName.of("max")),
// sample variance
VARSAMP(FunctionName.of("var_samp")),
// population standard variance
VARPOP(FunctionName.of("var_pop")),
// sample standard deviation.
STDDEV_SAMP(FunctionName.of("stddev_samp")),
// population standard deviation.
STDDEV_POP(FunctionName.of("stddev_pop")),

/**
* Text Functions.
Expand Down Expand Up @@ -189,7 +198,28 @@ public enum BuiltinFunctionName {
ALL_NATIVE_FUNCTIONS = builder.build();
}

private static final Map<String, BuiltinFunctionName> AGGREGATION_FUNC_MAPPING =
new ImmutableMap.Builder<String, BuiltinFunctionName>()
.put("max", BuiltinFunctionName.MAX)
.put("min", BuiltinFunctionName.MIN)
.put("avg", BuiltinFunctionName.AVG)
.put("count", BuiltinFunctionName.COUNT)
.put("sum", BuiltinFunctionName.SUM)
.put("var_pop", BuiltinFunctionName.VARPOP)
.put("var_samp", BuiltinFunctionName.VARSAMP)
.put("variance", BuiltinFunctionName.VARPOP)
.put("std", BuiltinFunctionName.STDDEV_POP)
.put("stddev", BuiltinFunctionName.STDDEV_POP)
.put("stddev_pop", BuiltinFunctionName.STDDEV_POP)
.put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP)
.build();

public static Optional<BuiltinFunctionName> of(String str) {
return Optional.ofNullable(ALL_NATIVE_FUNCTIONS.getOrDefault(FunctionName.of(str), null));
}

public static Optional<BuiltinFunctionName> ofAggregation(String functionName) {
return Optional.ofNullable(
AGGREGATION_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,14 @@ public void aggregation_filter() {
);
}

@Test
public void variance_mapto_varPop() {
assertAnalyzeEqual(
dsl.varPop(DSL.ref("integer_value", INTEGER)),
AstDSL.aggregate("variance", qualifiedName("integer_value"))
);
}

protected Expression analyze(UnresolvedExpression unresolvedExpression) {
return expressionAnalyzer.analyze(unresolvedExpression, analysisContext);
}
Expand Down
Loading